{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.text import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reduce original dataset to questions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset folder inside fastai's data directory\n",
    "path = Config().data_path()/'giga-fren'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You only need to execute the setup cells once; uncomment them to run. The dataset can be downloaded [here](https://s3.amazonaws.com/fast-ai-nlp/giga-fren.tgz)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#! wget https://s3.amazonaws.com/fast-ai-nlp/giga-fren.tgz -P {path}\n",
    "#! tar xf {path}/giga-fren.tgz -C {path} \n",
    "\n",
    "# with open(path/'giga-fren.release2.fixed.fr') as f:\n",
    "#    fr = f.read().split('\\n')\n",
    "\n",
    "# with open(path/'giga-fren.release2.fixed.en') as f:\n",
    "#    en = f.read().split('\\n')\n",
    "\n",
    "# re_eq = re.compile('^(Wh[^?.!]+\\?)')\n",
    "# re_fq = re.compile('^([^?.!]+\\?)')\n",
    "# en_fname = path/'giga-fren.release2.fixed.en'\n",
    "# fr_fname = path/'giga-fren.release2.fixed.fr'\n",
    "\n",
    "# lines = ((re_eq.search(eq), re_fq.search(fq)) \n",
    "#         for eq, fq in zip(open(en_fname, encoding='utf-8'), open(fr_fname, encoding='utf-8')))\n",
    "# qs = [(e.group(), f.group()) for e,f in lines if e and f]\n",
    "\n",
    "# qs = [(q1,q2) for q1,q2 in qs]\n",
    "# df = pd.DataFrame({'fr': [q[1] for q in qs], 'en': [q[0] for q in qs]}, columns = ['en', 'fr'])\n",
    "# df.to_csv(path/'questions_easy.csv', index=False)\n",
    "\n",
    "# del en, fr, lines, qs, df # free RAM or restart the nb "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### fastText pre-trained word vectors https://fasttext.cc/docs/en/crawl-vectors.html\n",
    "#! wget https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.fr.300.bin.gz -P {path}\n",
    "#! wget https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz -P {path}\n",
    "#! gzip -d {path}/cc.fr.300.bin.gz \n",
    "#! gzip -d {path}/cc.en.300.bin.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[PosixPath('/home/stas/.fastai/data/giga-fren/models'),\n",
       " PosixPath('/home/stas/.fastai/data/giga-fren/giga-fren.release2.fixed.en'),\n",
       " PosixPath('/home/stas/.fastai/data/giga-fren/giga-fren.release2.fixed.fr'),\n",
       " PosixPath('/home/stas/.fastai/data/giga-fren/data_save.pkl'),\n",
       " PosixPath('/home/stas/.fastai/data/giga-fren/cc.en.300.bin'),\n",
       " PosixPath('/home/stas/.fastai/data/giga-fren/questions_easy.csv'),\n",
       " PosixPath('/home/stas/.fastai/data/giga-fren/cc.fr.300.bin')]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# List the files currently in the dataset folder\n",
    "path.ls()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Put them in a DataBunch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our questions look like this now:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>en</th>\n",
       "      <th>fr</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>What is light ?</td>\n",
       "      <td>Qu’est-ce que la lumière?</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Who are we?</td>\n",
       "      <td>Où sommes-nous?</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Where did we come from?</td>\n",
       "      <td>D'où venons-nous?</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>What would we do without it?</td>\n",
       "      <td>Que ferions-nous sans elle ?</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>What is the absolute location (latitude and lo...</td>\n",
       "      <td>Quelle sont les coordonnées (latitude et longi...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                  en  \\\n",
       "0                                    What is light ?   \n",
       "1                                        Who are we?   \n",
       "2                            Where did we come from?   \n",
       "3                       What would we do without it?   \n",
       "4  What is the absolute location (latitude and lo...   \n",
       "\n",
       "                                                  fr  \n",
       "0                          Qu’est-ce que la lumière?  \n",
       "1                                    Où sommes-nous?  \n",
       "2                                  D'où venons-nous?  \n",
       "3                       Que ferions-nous sans elle ?  \n",
       "4  Quelle sont les coordonnées (latitude et longi...  "
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Question pairs saved by the (commented-out) setup cell\n",
    "df = pd.read_csv(path/'questions_easy.csv')\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make it simple, we lowercase everything."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df['en'] = df['en'].apply(lambda x:x.lower())\n",
    "df['fr'] = df['fr'].apply(lambda x:x.lower())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The first thing is that we need to collate inputs and targets in a batch: they have different lengths, so we need to add padding to make the sequence lengths the same."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def seq2seq_collate(samples:BatchSamples, pad_idx:int=1, pad_first:bool=True, backwards:bool=False) -> Tuple[LongTensor, LongTensor]:\n",
    "    \"Function that collect samples and adds padding. Flips token order if needed\"\n",
    "    samples = to_data(samples)\n",
    "    max_len_x,max_len_y = max([len(s[0]) for s in samples]),max([len(s[1]) for s in samples])\n",
    "    res_x = torch.zeros(len(samples), max_len_x).long() + pad_idx\n",
    "    res_y = torch.zeros(len(samples), max_len_y).long() + pad_idx\n",
    "    if backwards: pad_first = not pad_first\n",
    "    for i,s in enumerate(samples):\n",
    "        if pad_first: \n",
    "            res_x[i,-len(s[0]):],res_y[i,-len(s[1]):] = LongTensor(s[0]),LongTensor(s[1])\n",
    "        else:         \n",
    "            res_x[i,:len(s[0]):],res_y[i,:len(s[1]):] = LongTensor(s[0]),LongTensor(s[1])\n",
    "    if backwards: res_x,res_y = res_x.flip(1),res_y.flip(1)\n",
    "    return res_x,res_y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we create a special `DataBunch` that uses this collate function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Seq2SeqDataBunch(TextDataBunch):\n",
    "    \"Create a `TextDataBunch` suitable for training an RNN classifier.\"\n",
    "    @classmethod\n",
    "    def create(cls, train_ds, valid_ds, test_ds=None, path:PathOrStr='.', bs:int=32, val_bs:int=None, pad_idx=1,\n",
    "               pad_first=False, device:torch.device=None, no_check:bool=False, backwards:bool=False, **dl_kwargs) -> DataBunch:\n",
    "        \"Function that transform the `datasets` in a `DataBunch` for classification. Passes `**dl_kwargs` on to `DataLoader()`\"\n",
    "        datasets = cls._init_ds(train_ds, valid_ds, test_ds)\n",
    "        val_bs = ifnone(val_bs, bs)\n",
    "        collate_fn = partial(seq2seq_collate, pad_idx=pad_idx, pad_first=pad_first, backwards=backwards)\n",
    "        train_sampler = SortishSampler(datasets[0].x, key=lambda t: len(datasets[0][t][0].data), bs=bs//2)\n",
    "        train_dl = DataLoader(datasets[0], batch_size=bs, sampler=train_sampler, drop_last=True, **dl_kwargs)\n",
    "        dataloaders = [train_dl]\n",
    "        for ds in datasets[1:]:\n",
    "            lengths = [len(t) for t in ds.x.items]\n",
    "            sampler = SortSampler(ds.x, key=lengths.__getitem__)\n",
    "            dataloaders.append(DataLoader(ds, batch_size=val_bs, sampler=sampler, **dl_kwargs))\n",
    "        return cls(*dataloaders, path=path, device=device, collate_fn=collate_fn, no_check=no_check)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And a subclass of `TextList` that will use this `DataBunch` class when calling `.databunch()` and will use `TextList` to label (since our targets are other texts)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Seq2SeqTextList(TextList):\n",
    "    \"`TextList` whose `.databunch()` builds a `Seq2SeqDataBunch` and whose labels are `TextList`s too.\"\n",
    "    _bunch = Seq2SeqDataBunch  # DataBunch class used by `.databunch()`\n",
    "    _label_cls = TextList      # targets are texts as well"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That's all we need to use the data block API!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "src = Seq2SeqTextList.from_df(df, path = path, cols='fr').split_by_rand_pct().label_from_df(cols='en', label_cls=TextList)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "28.0"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 90th percentile of the input (French) sequence lengths, train and valid combined\n",
    "np.percentile([len(o) for o in src.train.x.items] + [len(o) for o in src.valid.x.items], 90)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "23.0"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 90th percentile of the target (English) sequence lengths, train and valid combined\n",
    "np.percentile([len(o) for o in src.train.y.items] + [len(o) for o in src.valid.y.items], 90)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We remove the items where either the input or the target is more than 30 tokens long."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# filter_by_func drops the items for which the function returns True\n",
    "src = src.filter_by_func(lambda x,y: len(x) > 30 or len(y) > 30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "48352"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of question pairs left after filtering\n",
    "len(src.train) + len(src.valid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `.databunch()` builds a Seq2SeqDataBunch (Seq2SeqTextList._bunch)\n",
    "data = src.databunch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Serialize the processed DataBunch so it can be reloaded with load_data\n",
    "data.save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the saved DataBunch from path\n",
    "data = load_data(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>text</th>\n",
       "      <th>target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>xxbos à quoi cela peut - il bien servir , alors que l’on xxunk toujours combien il y aura de ces unités et dans quels domaines elles seront présentes ?</td>\n",
       "      <td>xxbos what use was this , when it was still not known how many such units there would be and in what fields ?</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>xxbos quels autres fabricants de dispositifs médicaux avez - vous évalués et certifiés selon la norme iso xxunk : 2003 et le marquage ce ( le cas échéant ) ?</td>\n",
       "      <td>xxbos what medical xxunk companies has your organization audited and certified to iso xxunk and xxunk mark ( where applicable ) ?</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>xxbos quel est le lien entre le fep , les fonds structurels , le fonds de cohésion et le xxunk ( fonds européen agricole pour le développement rural ) ?</td>\n",
       "      <td>xxbos what is the link between the eff , structural funds , cohesion fund and xxunk ?</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>xxbos quel a été le rôle d'agriculture et agroalimentaire canada ( aac ) dans le processus de révision de la norme nationale sur l'agriculture biologique qui date de 1999 ?</td>\n",
       "      <td>xxbos what was the role of agriculture and agri - food canada ( aafc ) in the initiative to revise the 1999 national standard for organic agriculture ?</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>xxbos lesquelles des activités de r - d ci - après votre établissement a - t - il menées au cours des trois derniers exercices se terminant en 2003 ?</td>\n",
       "      <td>xxbos which of the following r&amp;d activities were carried out at your establishment over the last three fiscal years ending in 2003 ?</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Show a few (input, target) pairs\n",
    "data.show_batch()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pretrained embeddings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To install fastText:\n",
    "```\n",
    "$ git clone https://github.com/facebookresearch/fastText.git\n",
    "$ cd fastText\n",
    "$ pip install .\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Installation: https://github.com/facebookresearch/fastText#building-fasttext-for-python\n",
    "import fastText as ft"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fr_vecs = ft.load_model(str((path/'cc.fr.300.bin')))\n",
    "en_vecs = ft.load_model(str((path/'cc.en.300.bin')))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We create an embedding module with the pretrained vectors and random data for the missing parts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_emb(vecs, itos, em_sz=300, mult=1.):\n",
    "    emb = nn.Embedding(len(itos), em_sz, padding_idx=1)\n",
    "    wgts = emb.weight.data\n",
    "    vec_dic = {w:vecs.get_word_vector(w) for w in vecs.get_words()}\n",
    "    miss = []\n",
    "    for i,w in enumerate(itos):\n",
    "        try: wgts[i] = tensor(vec_dic[w])\n",
    "        except: miss.append(w)\n",
    "    return emb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encoder embeddings from French vectors, decoder embeddings from English vectors\n",
    "emb_enc = create_emb(fr_vecs, data.x.vocab.itos)\n",
    "emb_dec = create_emb(en_vecs, data.y.vocab.itos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cache the embeddings so the fastText models are not needed in later sessions\n",
    "torch.save(emb_enc, path/'models'/'fr_emb.pth')\n",
    "torch.save(emb_dec, path/'models'/'en_emb.pth')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Free some RAM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "del fr_vecs\n",
    "del en_vecs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### QRNN seq2seq"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our model uses QRNNs at its base (you can use GRUs or LSTMs by adapting it a little). Using QRNNs requires a properly installed CUDA (a version that matches your PyTorch install)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/stas/anaconda3/envs/fastai/lib/python3.7/site-packages/torch/utils/cpp_extension.py:166: UserWarning: \n",
      "\n",
      "                               !! WARNING !!\n",
      "\n",
      "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n",
      "Your compiler (c++) is not compatible with the compiler Pytorch was\n",
      "built with for this platform, which is g++ on linux. Please\n",
      "use g++ to to compile your extension. Alternatively, you may\n",
      "compile PyTorch from source using c++, and then you can also use\n",
      "c++ to compile your extension.\n",
      "\n",
      "See https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md for help\n",
      "with compiling PyTorch from source.\n",
      "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n",
      "\n",
      "                              !! WARNING !!\n",
      "\n",
      "  platform=sys.platform))\n",
      "/home/stas/anaconda3/envs/fastai/lib/python3.7/site-packages/torch/utils/cpp_extension.py:166: UserWarning: \n",
      "\n",
      "                               !! WARNING !!\n",
      "\n",
      "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n",
      "Your compiler (c++) is not compatible with the compiler Pytorch was\n",
      "built with for this platform, which is g++ on linux. Please\n",
      "use g++ to to compile your extension. Alternatively, you may\n",
      "compile PyTorch from source using c++, and then you can also use\n",
      "c++ to compile your extension.\n",
      "\n",
      "See https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md for help\n",
      "with compiling PyTorch from source.\n",
      "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n",
      "\n",
      "                              !! WARNING !!\n",
      "\n",
      "  platform=sys.platform))\n"
     ]
    }
   ],
   "source": [
    "from fastai.text.models.qrnn import QRNN, QRNNLayer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model itself consists of an encoder and a decoder.\n",
    "\n",
    "![Seq2seq model](images/seq2seq.png)\n",
    "\n",
    "The encoder is a (quasi) recurrent neural net and we feed it our input sentence, producing an output (that we discard for now) and a hidden state. That hidden state is then given to the decoder (another RNN), which uses it in conjunction with the outputs it predicts to produce the translation. We loop until the decoder produces a padding token for every sentence in the batch (or stop after 30 iterations to make sure it's not an infinite loop at the beginning of training). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Seq2SeqQRNN(nn.Module):\n",
    "    def __init__(self, emb_enc, emb_dec, n_hid, max_len, n_layers=2, p_inp:float=0.15, p_enc:float=0.25, \n",
    "                 p_dec:float=0.1, p_out:float=0.35, p_hid:float=0.05, bos_idx:int=0, pad_idx:int=1):\n",
    "        super().__init__()\n",
    "        self.n_layers,self.n_hid,self.max_len,self.bos_idx,self.pad_idx = n_layers,n_hid,max_len,bos_idx,pad_idx\n",
    "        self.emb_enc = emb_enc\n",
    "        self.emb_enc_drop = nn.Dropout(p_inp)\n",
    "        self.encoder = QRNN(emb_enc.weight.size(1), n_hid, n_layers=n_layers, dropout=p_enc)\n",
    "        self.out_enc = nn.Linear(n_hid, emb_enc.weight.size(1), bias=False)\n",
    "        self.hid_dp  = nn.Dropout(p_hid)\n",
    "        self.emb_dec = emb_dec\n",
    "        self.decoder = QRNN(emb_dec.weight.size(1), emb_dec.weight.size(1), n_layers=n_layers, dropout=p_dec)\n",
    "        self.out_drop = nn.Dropout(p_out)\n",
    "        self.out = nn.Linear(emb_dec.weight.size(1), emb_dec.weight.size(0))\n",
    "        self.out.weight.data = self.emb_dec.weight.data\n",
    "        \n",
    "    def forward(self, inp):\n",
    "        bs,sl = inp.size()\n",
    "        self.encoder.reset()\n",
    "        self.decoder.reset()\n",
    "        hid = self.initHidden(bs)\n",
    "        emb = self.emb_enc_drop(self.emb_enc(inp))\n",
    "        enc_out, hid = self.encoder(emb, hid)\n",
    "        hid = self.out_enc(self.hid_dp(hid))\n",
    "\n",
    "        dec_inp = inp.new_zeros(bs).long() + self.bos_idx\n",
    "        outs = []\n",
    "        for i in range(self.max_len):\n",
    "            emb = self.emb_dec(dec_inp).unsqueeze(1)\n",
    "            out, hid = self.decoder(emb, hid)\n",
    "            out = self.out(self.out_drop(out[:,0]))\n",
    "            outs.append(out)\n",
    "            dec_inp = out.max(1)[1]\n",
    "            if (dec_inp==self.pad_idx).all(): break\n",
    "        return torch.stack(outs, dim=1)\n",
    "    \n",
    "    def initHidden(self, bs): return one_param(self).new_zeros(self.n_layers, bs, self.n_hid)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Loss function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loss pads output and target so that they are of the same size before using the usual flattened version of cross entropy. We do the same for accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def seq2seq_loss(out, targ, pad_idx=1):\n",
    "    bs,targ_len = targ.size()\n",
    "    _,out_len,vs = out.size()\n",
    "    if targ_len>out_len: out  = F.pad(out,  (0,0,0,targ_len-out_len,0,0), value=pad_idx)\n",
    "    if out_len>targ_len: targ = F.pad(targ, (0,out_len-targ_len,0,0), value=pad_idx)\n",
    "    return CrossEntropyFlat()(out, targ)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def seq2seq_acc(out, targ, pad_idx=1):\n",
    "    bs,targ_len = targ.size()\n",
    "    _,out_len,vs = out.size()\n",
    "    if targ_len>out_len: out  = F.pad(out,  (0,0,0,targ_len-out_len,0,0), value=pad_idx)\n",
    "    if out_len>targ_len: targ = F.pad(targ, (0,out_len-targ_len,0,0), value=pad_idx)\n",
    "    out = out.argmax(2)\n",
    "    return (out==targ).float().mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Bleu metric (see dedicated notebook)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In translation, the metric usually used is BLEU, see the corresponding notebook for the details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NGram():\n",
    "    def __init__(self, ngram, max_n=5000): self.ngram,self.max_n = ngram,max_n\n",
    "    def __eq__(self, other):\n",
    "        if len(self.ngram) != len(other.ngram): return False\n",
    "        return np.all(np.array(self.ngram) == np.array(other.ngram))\n",
    "    def __hash__(self): return int(sum([o * self.max_n**i for i,o in enumerate(self.ngram)]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_grams(x, n, max_n=5000):\n",
    "    return x if n==1 else [NGram(x[i:i+n], max_n=max_n) for i in range(len(x)-n+1)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_correct_ngrams(pred, targ, n, max_n=5000):\n",
    "    pred_grams,targ_grams = get_grams(pred, n, max_n=max_n),get_grams(targ, n, max_n=max_n)\n",
    "    pred_cnt,targ_cnt = Counter(pred_grams),Counter(targ_grams)\n",
    "    return sum([min(c, targ_cnt[g]) for g,c in pred_cnt.items()]),len(pred_grams)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CorpusBLEU(Callback):\n",
    "    def __init__(self, vocab_sz):\n",
    "        self.vocab_sz = vocab_sz\n",
    "        self.name = 'bleu'\n",
    "    \n",
    "    def on_epoch_begin(self, **kwargs):\n",
    "        self.pred_len,self.targ_len,self.corrects,self.counts = 0,0,[0]*4,[0]*4\n",
    "    \n",
    "    def on_batch_end(self, last_output, last_target, **kwargs):\n",
    "        last_output = last_output.argmax(dim=-1)\n",
    "        for pred,targ in zip(last_output.cpu().numpy(),last_target.cpu().numpy()):\n",
    "            self.pred_len += len(pred)\n",
    "            self.targ_len += len(targ)\n",
    "            for i in range(4):\n",
    "                c,t = get_correct_ngrams(pred, targ, i+1, max_n=self.vocab_sz)\n",
    "                self.corrects[i] += c\n",
    "                self.counts[i]   += t\n",
    "    \n",
    "    def on_epoch_end(self, last_metrics, **kwargs):\n",
    "        precs = [c/t for c,t in zip(self.corrects,self.counts)]\n",
    "        len_penalty = exp(1 - self.targ_len/self.pred_len) if self.pred_len < self.targ_len else 1\n",
    "        bleu = len_penalty * ((precs[0]*precs[1]*precs[2]*precs[3]) ** 0.25)\n",
    "        return add_metrics(last_metrics, bleu)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We load our pretrained embeddings to create the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.load unpickles the saved modules: only load files you created yourself\n",
    "emb_enc = torch.load(path/'models'/'fr_emb.pth')\n",
    "emb_dec = torch.load(path/'models'/'en_emb.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 256 encoder hidden units; decode at most 30 tokens (the length threshold used when filtering)\n",
    "model = Seq2SeqQRNN(emb_enc, emb_dec, 256, 30, n_layers=2)\n",
    "learn = Learner(data, model, loss_func=seq2seq_loss, metrics=[seq2seq_acc, CorpusBLEU(len(data.y.vocab.itos))])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
     ]
    }
   ],
   "source": [
    "# Learning-rate range test; plot the result with learn.recorder.plot()\n",
    "learn.lr_find()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYsAAAEKCAYAAADjDHn2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xl8XHW9//HXZyZr0zRdkkL3FlqWArZAhAKCyCaibCIKwv0VN67XBQT1/vD6u+J1QWS5iqBiRRavggqoIHKlyFaVtYW2FAqWLrRpS5uma9oks31+f8xJOg2TTEgza97PR89jzvnOd+Z8vp0kn/me7/ecY+6OiIhIb0L5DkBERAqfkoWIiGSkZCEiIhkpWYiISEZKFiIikpGShYiIZKRkISIiGSlZiIhIRkoWIiKSUVm+Axgo9fX1Pnny5HyHISJSVBYsWLDJ3Rsy1SuZZDF58mTmz5+f7zBERIqKmb3Zl3o6DCUiIhkpWYiISEZKFiIikpGShYiIZKRkISIiGSlZiIhIRkoWIiKSkZKFiEgRu39BE3c/tzrr+1GyEBEpYn9cuJb7FqzJ+n6ULEREilhHLEF5OPt/ypUsRESKWDSeoKKsiJOFmd1uZhvNbElK2flm9oqZJcyssZfXnm5mr5vZG2Z2VbZiFBEpdpFYgooi71ncCZzerWwJ8GFgXk8vMrMw8GPgA8B04EIzm56lGEVEilo0XuSHodx9HrC5W9lSd389w0uPAt5w9xXuHgF+A5ydpTBFRIpaNO7FfRhqL4wDUof2m4KytzGzS81svpnNb25uzklwIiKFJDKIB7gtTZmnq+juc9y90d0bGxoy3rtDRKTkRIp9gHsvNAETUrbHA+vyFIuISEFLDnCn+449sAoxWbwATDOzKWZWAVwAPJjnmEREClIpTJ29B3gGONDMmszsU2Z2rpk1AccAfzazR4K6Y83sYQB3jwFfAB4BlgK/c/dXshWniEgxy9WYRdbuwe3uF/bw1B/S1F0HnJGy/TDwcJZCExEpCYmEE0v4oB3gFhGRPogmEgDFfRhKRESyKxILkoV6FiIi0pNoPHlWgXoWIiLSo86ehcYsRESkR9G4xixERCSDjq6exeA8KU9ERPqgq2ehw1AiItITHYYSEZGMNMAtIiIZRdSzEBGRTNSzEBGRjDpPyqtUz0JERHqinoWIiGTUORtK51mIiEiPNMAtIiIZ6aqzIiKSkU7KExGRjDTALSIiGalnISIiGXX2LMpCmg0lIiI9iMSdinAIsyJOFmZ2u5ltNLMlKWUjzexRM1sWPI7o4bVxM1sYLA9mK0YRkWIWjSdycggKstuzuBM4vVvZVcBj7j4NeCzYTqfN3WcGy1lZjFFEpGhFYomcnJAHWUwW7j4P2Nyt+GzgrmD9LuCcbO1fRKTUlUrPIp193H09QPA4uod6VWY238yeNTMlFBGRNJI9i9z8GS/LyV7euYnuvs7M9gMeN7OX3X1590pmdilwKcDEiRNzHaOISF5FSrhnscHMxgAEjxvTVXL3dcHjCuBJ4PAe6s1x90Z3b2xoaMhOxCIiBSoSS+TkUh+Q+2TxIDA7WJ8NPNC9gpmNMLPKYL0eOA54NWcRiogUiWg8d4ehsjl19h7gGeBAM2sys08B1wKnmtky4NRgGzNrNLPbgpceDMw3s0XAE8C17q5kISLSTTTuOTsMlbUxC3e/sIenTk5Tdz7w6WD9aeCwbMUlIlIqSmLqrIiIZFdygDuck30pWYiIFKnkALd6FiIi0otSPilPREQGSKQUZkOJiEh2RXN4BreShYhIkYrkcOqskoWISJGKxOIlewa3iIgMkFyelKdkISJSpJID3Jo6KyIiPYgnnHjCqQjrpDwREelBNJ4AoLxMPQsREelBJEgWGuAWEZEeRWNBstAAt4iI9KSzZ6GT8kREpEfRmAM6DCUiIr2IxOMAlOswlIiI9CSinoWIiGTS
OXW2QlNnRUSkJxrgFhGRjLqmzipZiIhITzq6zuBWshARkR6UTM/CzG43s41mtiSlbKSZPWpmy4LHET28dnZQZ5mZzc5WjCIixarrch8l0LO4Ezi9W9lVwGPuPg14LNjeg5mNBK4GjgaOAq7uKamIiAxW0VK5NpS7zwM2dys+G7grWL8LOCfNS98PPOrum919C/Aob086IiKDWucZ3KU6ZrGPu68HCB5Hp6kzDliTst0UlL2NmV1qZvPNbH5zc/OABysiUqi6BrgH8c2P0rXc01V09znu3ujujQ0NDVkOS0SkcHQOcFeW6M2PNpjZGIDgcWOaOk3AhJTt8cC6HMQmIlI0IiV+86MHgc7ZTbOBB9LUeQQ4zcxGBAPbpwVlIiISKKWps/cAzwAHmlmTmX0KuBY41cyWAacG25hZo5ndBuDum4FvAy8Ey7eCMhERCUTiCcwgHMpNz6IsW2/s7hf28NTJaerOBz6dsn07cHuWQhMRKXqReIKKcAiz0jwMJSIiAyAa85wdggIlCxGRohSJx3N2jgUoWYiIFCX1LEREJKNIPJGzabOgZCEiUpQ6B7hzRclCRKQIRWKJnN0lD5QsRESKUjSeoFID3CIi0ptoXD0LERHJQIehREQko0jcc3aXPFCyEBEpSupZiIhIRhrgFhGRjJI9C52UJyIivYjGExqzEBGR3mnqrIiIZNShAW4REcmkIAe4zWx/M6sM1k80s8vMbHh2QxMRkZ4U6tTZ+4G4mU0FfgFMAe7OWlQiItKjeMJJOAU5wJ1w9xhwLvBDd78CGJO9sEREpCeRWAKgIHsWUTO7EJgNPBSUlWcnJBER6U0knkwWhdiz+ARwDPBdd19pZlOAX/V3p2Z2uZktMbNXzOxLaZ4/0cy2mdnCYPlGf/clIlJqop3JIocn5ZX1pZK7vwpcBmBmI4Bad7+2Pzs0s0OBzwBHARHgL2b2Z3df1q3q39z9Q/3Zh4hIKes8DFVwPQsze9LMhpnZSGARcIeZ/Xc/93kw8Ky77wrGQZ4iORYiIiJ90NmzKMQxizp33w58GLjD3Y8ETunnPpcAJ5jZKDMbApwBTEhT7xgzW2Rm/2tmh/RzXyIiJScfA9x9OgwFlJnZGOCjwNf3ZofuvtTMvg88CrSS7KnEulV7EZjk7q1mdgbwR2Ba9/cys0uBSwEmTpy4N2GJiBSNQh7g/hbwCLDc3V8ws/2A7mMMfebuv3D3I9z9BGBz9/dy9+3u3hqsPwyUm1l9mveZ4+6N7t7Y0NDQ33BERIpK15hFofUs3P1e4N6U7RXAef3dqZmNdveNZjaR5KGtY7o9vy+wwd3dzI4imdRa+rs/EZFSEo07kNueRZ+ShZmNB24GjgMc+Dtwubs39XO/95vZKCAKfN7dt5jZZwHc/VbgI8C/mVkMaAMucHfv575EREpKPga4+zpmcQfJy3ucH2xfHJSd2p+duvvxacpuTVm/BbilP+8tIlLqCnbqLNDg7ne4eyxY7gQ0SCAikgeRrp5F4d0pb5OZXWxm4WC5GI0hiIjkRT4GuPu6p0+SnDb7FrCe5JjCJ7IVlIiI9CxaqFNn3X21u5/l7g3uPtrdzyE5i0lERHKskK86m86VAxaFiIj0WcH2LHqQu5EVERHpEgnOsyiWnoXOexARyYPOw1C5vAd3r+dZmNkO0icFA6qzEpGIiPSq4E7Kc/faXAUiIiJ9E4klCBmEQ4V3noWIiBSIaDyR08FtULIQESk6HbFETg9BgZKFiEjRicYTOR3cBiULEZGiE42rZyEiIhlEYhqzEBGRDKJxV89CRER6pwFuERHJSFNnRUQko0gsQUUOb3wEShYiIkVHPQsREclIU2dFRCSjjlgip7dUBSULEZGiE40nKB8Mh6HM7HIzW2Jmr5jZl9I8b2b2IzN7w8wWm9kR+YhTRKQQReKDoGdhZocCnwGOAmYAHzKzad2qfQCYFiyXAj/NaZAiIgUsGvOcJ4te72eRJQcDz7r7LgAzewo4F7gu
pc7ZwC/d3YFnzWy4mY1x9/UDHUwsnmDlpp2EQkbYjHAouYTMCBlgEDLDHRwn+AeAByuOB8+D++57RZkZBphBwpPPue9+nQUz30KhZL2QGWa93K/WwDLczbbzPS1l/6nlnXF1Xgs/2c7ONiefE5HCFoknKC/L7e9qPpLFEuC7ZjYKaAPOAOZ3qzMOWJOy3RSUDXiy2NoW5dQfzBvoty1aIUtNWhYkS7qSSmeSCYdClIV2J9fOZFMWChEOGWXhZFlnnfJwsn5ZOER52CgLhSgP1ivKkusVZSEqwiGqysNUlScfq8vDyceK5Hp1eZiayjA1lWUMqQhTU1FGKIc3gBEpBNFYgopwOKf7zHmycPelZvZ94FGgFVgExLpVS/fb/7bbu5rZpSQPUzFx4sR+xTO0soybLzychDvxxO7FgYR7V4/Akjvs6il0fmff85v87j+w+J49js5v7aHgPTp7IV2PQb1EsN79C76n9GjSVugs73rv3b0c71bFgUTCk212J5Fw4gm61hNdce3uTSX/f+j6f4olnHgiQSwRvN6T7xlLJJLvFTwXiyfrt8ZixOJONN5ZniAabEfiCaKx4DH+zm/tXlMRpraqnKFVZdRWlVFXXc7w6nLqqsupG1LBqJoKRg2tYGRNBQ1DK9mnrorayjL1oqRoDZaeBe7+C+AXAGZ2DcmeQ6omYELK9nhgXZr3mQPMAWhsbHznf2WAqvIwZ84Y25+XShYkEk5HLEF7NE57LE57NEFbJE5bNE57NM6uSJxdkRitHTF2dsRo7YjT2h6jtSNKa0eM7W0xWlojrGjeydZdEba3d/8eklRTEWafuir2HZZcOtdH11Yysiu5VDK8ulw9Fyko7k4knqByEIxZYGaj3X2jmU0EPgwc063Kg8AXzOw3wNHAtmyMV0jhCYUsecipYmC62LF4gi27omzeGaFlZwebWiNs2NbO+m3tvLW9jbe2tfPcys1s2N5OLPH27xvhkFE/tIKG2kpG11bRMLSS+toK6odWpizJ7TolFsmBeCJ59CHXJ+XlJVkA9wdjFlHg8+6+xcw+C+DutwIPkxzLeAPYBXwiT3FKkSsLh2ioraShthKo7bFeIuG07IywcUc7m3dGksmlNZlgmnd0sHFHB29ta2fJ2m207IwQ7yGx7Dusiin1NUwaNSR4rGFK/RAmjBxCZVlujzFLaYrEEwA5P88iX4ehjk9TdmvKugOfz2lQMqiFQpaSVHqXSDhb26I07+igpbWDTTsjbNrRwabWDtZubWNVyy4eWryebW3R3e9vMHZ4NZNGDWH88CFMGFnN+BFDGD8i+Ti6tlK9EumTaCz5RWUwTJ0VKWqhkDGyJjlg3ltvZcvOCKtadrKqZScrN+1i1aadrNmyi8de28im1o496paHjbHDq5kwItkLmRgsU+prmFJfM2CH5aT4DaqehchgMKKmghE1FRw+ccTbnmuLxFm7dRdNW9pSll2s2dLGI6+8xeadka66ZjBueDX7Nwxl+thhvGtcHYeNr2Pc8GrN6BqEOpPFoBjgFhnsqivCTB1dy9TR6XsmrR0xVrfsYuWmnSxvbmV5cyvLNrTy83krugbiR9ZU0DhpBLP2G8Ws/UZx0L61OpQ1CERjnT2LQTB1VkR6N7SyjOljhzF97LA9ytujcV5/awcvr93GojVbeX7VZua+ugGAuupyjp9Wz/sOHM17D2ygfmjm8RcpPp09i5I/KU9E+q+qPMyMCcOZMWE4F8+aBMC6rW08t7KFf7zRwpOvN/PQ4uQs8xnj6zhzxljOmjGW0cOq8hm2DKBIZ88ix3fKU7IQKXJjh1dz7uHjOffw8SQSzqvrt/PEaxt5dOkGvvPnpVzz8FKOm1rPOTPH8cF3jaGqXIPlxUwD3CKy10Ih49BxdRw6ro4vnjyN5c2tPPDSWv6wcC1fvncR1zy8lItnTeJfjpmkw1RFqnPMItcD3Lr5kUgJ279hKFeediDzvvo+7v700cycMJybHlvGsdc+zlX3
L6Zpy658hyjvUOf109SzEJEBZ2YcO7WeY6fWs7y5lTv+sZJ75zfx+5fWcsmxk/n8iVOpG1Ke7zClDyLxOJD7k/LUsxAZZPZvGMp3zjmMJ75yImfNGMvP/7aC4697nDnzlncNnkrhigRncOf62lBKFiKD1Njh1dxw/gwevux4Dp84gmsefo3Tb5rH35dtyndo0ouuqbOD4R7cIlI4Dh4zjLs+eRS3X9JILO5c/Ivn+PzdL7J+W1u+Q5M0Oge4dRhKRPLipIP2Ye4VJ3DlqQfw11c3cNp/z+Pp5eplFJrdU2dze56FkoWIdKkqD3PZydN49Ir3MmZ4FZfc/gJ/WaJbyRSSaFw9CxEpEBNHDeF3/3oMh44bxud+/SL3PL863yFJoOsMbo1ZiEghGD6kgl99+mhOOKCBr/3+ZX78xBv5DklIvTaUkoWIFIghFWX8/P80cs7MsVz/yOvc/NiyfIc06OnmRyJSkMrDIW786ExCZtz46D8Jh43PnTg132ENWpF4nLKQ5fxy9EoWIpJROGRcf/4M4u5c95fXKQ+F+MwJ++U7rEEpGvecn5AHShYi0kfhkHHj+TOIJZzvPryUcMj45Hum5DusQScSS+T88uSgZCEi70BZOMQPPzaTWDzBt//8KgfuW8txU+vzHdagEoknqCjL/WXm8zLAbWZXmNkrZrbEzO4xs6puz19iZs1mtjBYPp2POEXk7crDIX7wsZnsV1/DFb9dSEtrR75DGlSisQQVeehZ5DxZmNk44DKg0d0PBcLABWmq/tbdZwbLbTkNUkR6NaSijJsvPIKtu6J89b7FuHu+Qxo0kj2L3H/Pz9fU2TKg2szKgCHAujzFISL9NH3sMP7jjIN4/LWN3Pn0qnyHM2jsisQHR7Jw97XADcBqYD2wzd3npql6npktNrP7zGxCToMUkT6ZfexkTjl4NN97+DWWrN2W73AGhaXrtzNtdG3O95uPw1AjgLOBKcBYoMbMLu5W7U/AZHd/F/BX4K4e3utSM5tvZvObm5uzGbaIpGFmXPeRGYyoKeey37xEezSe75BK2qbWDpq2tDFjQl3O952Pw1CnACvdvdndo8DvgWNTK7h7i7t3jpr9HDgy3Ru5+xx3b3T3xoaGhqwGLSLpjayp4PqPzGBF805uflxneGfT4qatAMwYPzzn+85HslgNzDKzIWZmwMnA0tQKZjYmZfOs7s+LSGE54YAGzjtiPD97agWvrtue73BK1sI12wgZHDZ+EPQs3P054D7gReDlIIY5ZvYtMzsrqHZZMLV2EcmZU5fkOk4ReWf+80MHM3xIOVf9fjGxuG7Pmg0L12zlgH1qGVKR+1Pk8jIbyt2vdveD3P1Qd/8Xd+9w92+4+4PB819z90PcfYa7v8/dX8tHnCLSd8OHVPDNsw5hcdM2zY7KAndn0ZqtzJyQ+0NQoKvOisgA+uBhYzjl4NHcMPd1Vrfsync4JeXNll1sa4sqWYhI8TMzvn3OoZSFQvz7/YuIJ3Sy3kBZuCYY3FayEJFSMKaumqvPnM6zKzZz/SOv5zuckrFwzVaqy8NMGz00L/tXshCRAXd+4wQuOnoitz61nIcW6wINA2FR01YOG1dHWR4uTw5KFiKSJVefeQhHThrBV+9dzGtvaTrt3ojEEryybjszJ+bnEBQoWYhIllSUhfjpRUdQW1XGpb9cwNZdkXyHVLRee2s7kVgiLyfjdVKyEJGsGT2sip9efCTrt7Xx4Z8+zff+dylPvL6R1o5YvkMrKou6BrdzfzJeJ938SESy6shJI7jl40dw299WcPvfV/Kzp1YQDhnjhldTVR6iqjxMVVmYWCJBa0eMHe0xWttjjBxaweEThnP4xBHMnDCcafsMzcvJaIVg4Zpt1A+tZNzw6rzFMDj/50Ukp95/yL68/5B9aYvEWfDmFp5d0ULTll20RxO0x+K0R+NUl4UZXVvF0KoyhlaWsX5bG08vb+GPC3cPkNdVlzOmropxw6v5txP3p3HyyDy2KncW
rtnCzAl1JK+QlB9KFiKSM9UVYd4zrZ73TOvbrVjdnfXb2lm4ZiurWnayfms767e18dLqrXz53kU8duV78zY7KFe2t0dZ3ryTc2aOy2scShYiUrDMjLHDqxnb7fDL3Ffe4tL/WcAfXlrL+Y2lfbubl5uS9wnJ18l4nUo7JYtISTp1+j5MHzOMW554o+QvWth15nYeZ0KBkoWIFCEz4/JTpvFmyy4eWFjaJ/0teHML+zXUUDekPK9xKFmISFE6Lehd3Pz4spLtXcTiCZ5fuZlj9huV71CULESkOJkZl508jVUtu3hwUWn2Ll5eu43WjhjH7K9kISLSb6dN34eDxwzjlsdLc+zi6eUtAMxSz0JEpP9CIePyk6eyYtNOrnn4NZp3dOQ7pAH17IoWDtynlvqhlfkORclCRIrbadP35YPvGsPt/1jJsdc+xuW/eYkFb27GvbjvpdERi/PCqs0FcQgKlCxEpMiFQsaPP34Ef73yvVx09CQeX7qR8376DDfMLe57aSxas432aELJQkRkIE0dPZRvnnUIz/7HyZx3xHh+/MRyHn9tQ77D6renl2/CDGZNUbIQERlwNZVlfPfcQzl4zDCu/N0i1m5ty3dI/fLM8hYOGTss7+dXdFKyEJGSU1Ue5icXHUEs7nzx7heJFtlMqfZonJdWb+XY/ft2Da1cyEuyMLMrzOwVM1tiZveYWVW35yvN7Ldm9oaZPWdmk/MRp4gUryn1NXzvw4fx4uqte9wL3N3Z0R5lydpt/GnROm766zK++eArrNm8K4/R7mnBm1uIxBMFcTJep5xfSNDMxgGXAdPdvc3MfgdcANyZUu1TwBZ3n2pmFwDfBz6W61hFpLidOWMsz61sYc68FfxlyVvsaI+yvT1GPLHnTKmykPHcys384XPHUlUezmmMOztirN68i4PHDOsqe2Z5C+GQ8e4phXMJ9nxddbYMqDazKDAE6H765dnAN4P1+4BbzMy82OfCiUjO/b8PTqcsFGLLrgjDqsoZVl3GsKpyJowcwpT6GiaPquHZFS184s4X+MYDS7juIzNyGt/VD77CfQua+PKpB/CFk6ZiZjy9fBPvGl/H0MrCuTB4ziNx97VmdgOwGmgD5rr73G7VxgFrgvoxM9sGjAI25TRYESl6VeVhvnnWIb3Wed9Bo/niSVO5+fE3aJw0ko++OzeXPd+8M8KDi9ZRP7SSGx/9J8ubW/nGmYewuGkb//re/XISQ1/lfMzCzEaQ7DlMAcYCNWZ2cfdqaV76tl6FmV1qZvPNbH5zc/PABysig8aXTjmA46aO4j8fWMIr67blZJ+/m7+GSCzBrz99NF857QD+uHAdH/rR34glnGP2K5zBbcjPAPcpwEp3b3b3KPB74NhudZqACQBmVgbUAZu7v5G7z3H3RndvbGhoyHLYIlLKwiHjpgsOZ8SQCj736xf7femQltYO7n5uNQ8sXMvTb2zinxt2sLMj9rZ68YTzq2ff5OgpIzlw31q+cNI0fnLREWzeFaEiHOLISSP2tkkDKh8HxFYDs8xsCMnDUCcD87vVeRCYDTwDfAR4XOMVIpJt9UMr+fFFh3Phz5/jAzf9jRs/OoP3HtD3L6L/3LCDT975Ak1b9jy3o7aqjPs+eywH7lvbVfbUPzfStKWNr33g4K6yMw4bw/4NQ9mwvZ3qitwOtGeS856Fuz9HctD6ReDlIIY5ZvYtMzsrqPYLYJSZvQFcCVyV6zhFZHA6ctJIHvzCcYysKWf27c/znYdepSMWz/i6J1/fyId/8jQdsQS/vXQWf73yBO75zCxuumAmVeVhPvfrBXv0MH75zJuMrq3ktEP22eN9Dty3lhPeQYLKFSuVL+yNjY0+f373DoqISP+0R+N87+Gl3PXMm0wfM4zvn/cuDhtf97Z67s5dT6/iWw+9ykH7DuO22Y1vu2f408s3cfFtz3H2zHH890dnsHrzLk684UkuO2kaV5x6QK6alJaZLXD3xkz1CmdelohIAakqD/NfZx/K
CQc08H/vf5mzfvx3Pn7URL5y2oGMqKkgkXDmvrqBn81bzkurt3Lq9H344cdmUpNmuuux+9dz+ckH8IO//pNZ+41kefNOQmZceNTEPLSsf9SzEBHJYHt7lB8+uoy7nllFbVUZF7x7InNfeYsVm3YyYWQ1l56wPxcdNZFQKN1EzqR4wpl9+/O8sGozFWUhjp9Wz08uOjJ3jehBX3sWujaUiEgGw6rK+caZ0/nzZe/hgNG13PrUcoZUhrnl44fzxJdP5F9mTeo1UUByttUPPjaTYdXl7GiPcfGsSTmKfmDoMJSISB8dtO8wfvuvs3hrezv7DqvCrPcE0V1DbSW3z34385Y1F9R1n/pCyUJE5B0wM8bUVWeu2IPDxtelHSgvdDoMJSIiGSlZiIhIRkoWIiKSkZKFiIhkpGQhIiIZKVmIiEhGShYiIpKRkoWIiGRUMteGMrNm4M00T9UB3W971b0sdTvdempZPf27vWu6OPpaZyDakLre3zb0FmNf6vQWc6bt7p9FobQhXVmhfBa9Pd/fz6KQf57Slel3O7NJ7p75mujuXtILMCdTWep2uvVuZfMHKo6+1hmINnRrT7/aMNDteCfb3T+LQmlDIX8WvT3f38+ikH+e+vNZ6He778tgOAz1pz6U/SnDerr3GIg4+lpnINrQ1xgyGch2vJNtfRZ9i6Wvz/f3syjkn6d0ZfrdHiAlcxgqV8xsvvfhcr6FTG0oHKXQjlJoA5RGO7LZhsHQsxhoc/IdwABQGwpHKbSjFNoApdGOrLVBPQsREclIPQsREclo0CYLM7vdzDaa2ZJ+vPZIM3vZzN4wsx9Zyh1QzOyLZva6mb1iZtcNbNRpYxnwdpjZN81srZktDJYzBj7yPeLIymcRPP8VM3Mzqx+4iHuMJRufxbfNbHHwOcw1s7EDH/kecWSjDdeb2WtBO/5gZsMHPvI94shGG84PfqcTZpa1cY29ib2H95ttZsuCZXZKea+/N2lla5pVoS/ACcARwJJ+vPZ54BjAgP8FPhCUvw/4K1AZbI8u0nZ8E/hKMX8WwXMTgEdInn9TX4ztAIal1LkMuLUI23AaUBasfx/4fhG24WDgQOBJoLHQYg/imtytbCSwIngcEayP6K2dvS2Dtmfh7vOAzamsoaX0AAAGCUlEQVRlZra/mf3FzBaY2d/M7KDurzOzMSR/gZ/x5P/6L4Fzgqf/DbjW3TuCfWzMbiuy1o6cymIbfgD8O5CTgblstMPdt6dUrSHLbclSG+a6eyyo+iwwvgjbsNTdX89m3HsTew/eDzzq7pvdfQvwKHB6f3/3B22y6MEc4IvufiTwFeAnaeqMA5pStpuCMoADgOPN7Dkze8rM3p3VaHu2t+0A+EJw2OB2MxuRvVB7tFdtMLOzgLXuvijbgWaw15+FmX3XzNYAFwHfyGKsPRmIn6dOnyT5TTbXBrINudaX2NMZB6xJ2e5sT7/aqXtwB8xsKHAscG/K4bvKdFXTlHV+2ysj2d2bBbwb+J2Z7Rdk75wYoHb8FPh2sP1t4EaSv+Q5sbdtMLMhwNdJHv7ImwH6LHD3rwNfN7OvAV8Arh7gUHs0UG0I3uvrQAz49UDGmMlAtiHXeovdzD4BXB6UTQUeNrMIsNLdz6Xn9vSrnUoWu4WAre4+M7XQzMLAgmDzQZJ/SFO70eOBdcF6E/D7IDk8b2YJktdqac5m4N3sdTvcfUPK634OPJTNgNPY2zbsD0wBFgW/YOOBF83sKHd/K8uxpxqIn6lUdwN/JofJggFqQzC4+iHg5Fx+eQoM9OeQS2ljB3D3O4A7AMzsSeASd1+VUqUJODFlezzJsY0m+tPObA3UFMMCTCZlIAl4Gjg/WDdgRg+ve4Fk76FzcOiMoPyzwLeC9QNIdgGtCNsxJqXOFcBviq0N3eqsIgcD3Fn6LKal1PkicF8RtuF04FWgIRefQTZ/nsjyAHd/Y6fnAe6VJI92jAjWR/alnWnjytWHV2gLcA+wHoiSzLSfIvlt9C/A
ouCH+xs9vLYRWAIsB25h98mNFcCvgudeBE4q0nb8D/AysJjkN64xxdaGbnVWkZvZUNn4LO4PyheTvP7PuCJswxskvzgtDJZsz+jKRhvODd6rA9gAPFJIsZMmWQTlnwz+/98APvFOfm+6LzqDW0REMtJsKBERyUjJQkREMlKyEBGRjJQsREQkIyULERHJSMlCSpqZteZ4f7eZ2fQBeq+4Ja82u8TM/pTpaq1mNtzMPjcQ+xbpTlNnpaSZWau7Dx3A9yvz3RfFy6rU2M3sLuCf7v7dXupPBh5y90NzEZ8MLupZyKBjZg1mdr+ZvRAsxwXlR5nZ02b2UvB4YFB+iZnda2Z/Auaa2Ylm9qSZ3WfJ+zT8uvN+AEF5Y7DeGlwEcJGZPWtm+wTl+wfbL5jZt/rY+3mG3RdJHGpmj5nZi5a8J8HZQZ1rgf2D3sj1Qd2vBvtZbGb/NYD/jTLIKFnIYHQT8AN3fzdwHnBbUP4acIK7H07y6q7XpLzmGGC2u58UbB8OfAmYDuwHHJdmPzXAs+4+A5gHfCZl/zcF+894TZ7gGkYnkzybHqAdONfdjyB5D5Ubg2R1FbDc3We6+1fN7DRgGnAUMBM40sxOyLQ/kXR0IUEZjE4BpqdcxXOYmdUCdcBdZjaN5FU4y1Ne86i7p95n4Hl3bwIws4Ukr+fz9277ibD7IowLgFOD9WPYff+Au4EbeoizOuW9F5C8HwEkr+dzTfCHP0Gyx7FPmtefFiwvBdtDSSaPeT3sT6RHShYyGIWAY9y9LbXQzG4GnnD3c4Pj/0+mPL2z23t0pKzHSf+7FPXdg4I91elNm7vPNLM6kknn88CPSN7XogE40t2jZrYKqErzegO+5+4/e4f7FXkbHYaSwWguyftCAGBmnZd/rgPWBuuXZHH/z5I8/AVwQabK7r6N5C1Vv2Jm5STj3BgkivcBk4KqO4DalJc+AnwyuCcCZjbOzEYPUBtkkFGykFI3xMyaUpYrSf7hbQwGfV8leWl5gOuA75nZP4BwFmP6EnClmT0PjAG2ZXqBu79E8qqjF5C8eVCjmc0n2ct4LajTAvwjmGp7vbvPJXmY6xkzexm4jz2TiUifaeqsSI4Fd/Jrc3c3swuAC9397EyvE8knjVmI5N6RwC3BDKat5PCWtSL9pZ6FiIhkpDELERHJSMlCREQyUrIQEZGMlCxERCQjJQsREclIyUJERDL6/wC4QGktJS2xAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.recorder.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>seq2seq_acc</th>\n",
       "      <th>bleu</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>6.272254</td>\n",
       "      <td>6.547584</td>\n",
       "      <td>0.175653</td>\n",
       "      <td>0.084508</td>\n",
       "      <td>00:35</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>5.475595</td>\n",
       "      <td>5.798847</td>\n",
       "      <td>0.237578</td>\n",
       "      <td>0.177244</td>\n",
       "      <td>00:34</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>4.998140</td>\n",
       "      <td>4.741757</td>\n",
       "      <td>0.342352</td>\n",
       "      <td>0.250401</td>\n",
       "      <td>00:36</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>4.769568</td>\n",
       "      <td>4.965292</td>\n",
       "      <td>0.316322</td>\n",
       "      <td>0.226495</td>\n",
       "      <td>00:38</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>4.218278</td>\n",
       "      <td>4.942849</td>\n",
       "      <td>0.316456</td>\n",
       "      <td>0.239042</td>\n",
       "      <td>00:37</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>3.686281</td>\n",
       "      <td>4.311011</td>\n",
       "      <td>0.379345</td>\n",
       "      <td>0.282809</td>\n",
       "      <td>00:39</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>3.294988</td>\n",
       "      <td>4.044959</td>\n",
       "      <td>0.409902</td>\n",
       "      <td>0.317913</td>\n",
       "      <td>00:41</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>2.959656</td>\n",
       "      <td>3.956887</td>\n",
       "      <td>0.420079</td>\n",
       "      <td>0.321248</td>\n",
       "      <td>00:42</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit_one_cycle(8, 1e-2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So how good is our model? Let's see a few predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_predictions(learn, ds_type=DatasetType.Valid):\n",
    "    learn.model.eval()\n",
    "    inputs, targets, outputs = [],[],[]\n",
    "    with torch.no_grad():\n",
    "        for xb,yb in progress_bar(learn.dl(ds_type)):\n",
    "            out = learn.model(xb)\n",
    "            for x,y,z in zip(xb,yb,out):\n",
    "                inputs.append(learn.data.train_ds.x.reconstruct(x))\n",
    "                targets.append(learn.data.train_ds.y.reconstruct(y))\n",
    "                outputs.append(learn.data.train_ds.y.reconstruct(z.argmax(1)))\n",
    "    return inputs, targets, outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "                background: #F44336;\n",
       "            }\n",
       "        </style>\n",
       "      <progress value='152' class='' max='152', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      100.00% [152/152 00:17<00:00]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "inputs, targets, outputs = get_predictions(learn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Text xxbos pour quelle raison demandez - vous aux émetteurs des renseignements qui n'ont pas à être fournis sur les reçus papier remis aux contribuables ?,\n",
       " Text xxbos why are your requiring xxunk to provide information that is not required to be on the paper receipts given to clients ?,\n",
       " Text xxbos why would you you to to to to to to the the the the the the ? ?)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs[700], targets[700], outputs[700]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Text xxbos quels facteurs sont responsables des différences de concentrations des contaminants présents dans les poissons dans les cours d’eau et les lacs du nord ?,\n",
       " Text xxbos what factors are responsible for the differences in the level of contaminants found fish in northern rivers and lakes ?,\n",
       " Text xxbos what are the differences between the in the the the the the the ? ?)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs[701], targets[701], outputs[701]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Text xxbos quel est l'impact sur la recherche en amont du brevetage accru dans les sciences du vivant ?,\n",
       " Text xxbos what is the impact on upstream research of increased patenting in the life sciences ?,\n",
       " Text xxbos what is the impact of on on on on on on on on ? ?)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs[2513], targets[2513], outputs[2513]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Text xxbos quels changements devrait - on apporter aux processus de réglementation fédéraux et provinciaux ?,\n",
       " Text xxbos what changes to federal and provincial regulatory processes would be required ?,\n",
       " Text xxbos what changes will be be to the the the the the public ?)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs[4000], targets[4000], outputs[4000]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It's usually beginning well, but falls into easy word at the end of the question."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Teacher forcing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to help training is to help the decoder by feeding it the real targets instead of its predictions (if it starts with wrong words, it's very unlikely to give us the right translation). We do that all the time at the beginning, then progressively reduce the amount of teacher forcing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TeacherForcing(LearnerCallback):\n",
    "    \n",
    "    def __init__(self, learn, end_epoch):\n",
    "        super().__init__(learn)\n",
    "        self.end_epoch = end_epoch\n",
    "    \n",
    "    def on_batch_begin(self, last_input, last_target, train, **kwargs):\n",
    "        if train: return {'last_input': [last_input, last_target]}\n",
    "    \n",
    "    def on_epoch_begin(self, epoch, **kwargs):\n",
    "        self.learn.model.pr_force = 1 - 0.5 * epoch/self.end_epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Seq2SeqQRNN(nn.Module):\n",
    "    def __init__(self, emb_enc, emb_dec, n_hid, max_len, n_layers=2, p_inp:float=0.15, p_enc:float=0.25, \n",
    "                 p_dec:float=0.1, p_out:float=0.35, p_hid:float=0.05, bos_idx:int=0, pad_idx:int=1):\n",
    "        super().__init__()\n",
    "        self.n_layers,self.n_hid,self.max_len,self.bos_idx,self.pad_idx = n_layers,n_hid,max_len,bos_idx,pad_idx\n",
    "        self.emb_enc = emb_enc\n",
    "        self.emb_enc_drop = nn.Dropout(p_inp)\n",
    "        self.encoder = QRNN(emb_enc.weight.size(1), n_hid, n_layers=n_layers, dropout=p_enc)\n",
    "        self.out_enc = nn.Linear(n_hid, emb_enc.weight.size(1), bias=False)\n",
    "        self.hid_dp  = nn.Dropout(p_hid)\n",
    "        self.emb_dec = emb_dec\n",
    "        self.decoder = QRNN(emb_dec.weight.size(1), emb_dec.weight.size(1), n_layers=n_layers, dropout=p_dec)\n",
    "        self.out_drop = nn.Dropout(p_out)\n",
    "        self.out = nn.Linear(emb_dec.weight.size(1), emb_dec.weight.size(0))\n",
    "        self.out.weight.data = self.emb_dec.weight.data\n",
    "        self.pr_force = 0.\n",
    "        \n",
    "    def forward(self, inp, targ=None):\n",
    "        bs,sl = inp.size()\n",
    "        hid = self.initHidden(bs)\n",
    "        emb = self.emb_enc_drop(self.emb_enc(inp))\n",
    "        enc_out, hid = self.encoder(emb, hid)\n",
    "        hid = self.out_enc(self.hid_dp(hid))\n",
    "\n",
    "        dec_inp = inp.new_zeros(bs).long() + self.bos_idx\n",
    "        res = []\n",
    "        for i in range(self.max_len):\n",
    "            emb = self.emb_dec(dec_inp).unsqueeze(1)\n",
    "            outp, hid = self.decoder(emb, hid)\n",
    "            outp = self.out(self.out_drop(outp[:,0]))\n",
    "            res.append(outp)\n",
    "            dec_inp = outp.data.max(1)[1]\n",
    "            if (dec_inp==self.pad_idx).all(): break\n",
    "            if (targ is not None) and (random.random()<self.pr_force):\n",
    "                if i>=targ.shape[1]: break\n",
    "                dec_inp = targ[:,i]\n",
    "        return torch.stack(res, dim=1)\n",
    "    \n",
    "    def initHidden(self, bs): return one_param(self).new_zeros(self.n_layers, bs, self.n_hid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "emb_enc = torch.load(path/'models'/'fr_emb.pth')\n",
    "emb_dec = torch.load(path/'models'/'en_emb.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Seq2SeqQRNN(emb_enc, emb_dec, 256, 30, n_layers=2)\n",
    "learn = Learner(data, model, loss_func=seq2seq_loss, metrics=[seq2seq_acc, CorpusBLEU(len(data.y.vocab.itos))],\n",
    "                callback_fns=partial(TeacherForcing, end_epoch=8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>seq2seq_acc</th>\n",
       "      <th>bleu</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>2.335030</td>\n",
       "      <td>4.213064</td>\n",
       "      <td>0.543526</td>\n",
       "      <td>0.311808</td>\n",
       "      <td>00:50</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>2.240968</td>\n",
       "      <td>4.949047</td>\n",
       "      <td>0.414702</td>\n",
       "      <td>0.356721</td>\n",
       "      <td>00:46</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>2.030350</td>\n",
       "      <td>5.073238</td>\n",
       "      <td>0.391867</td>\n",
       "      <td>0.354593</td>\n",
       "      <td>00:46</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>2.117243</td>\n",
       "      <td>4.553541</td>\n",
       "      <td>0.430130</td>\n",
       "      <td>0.382721</td>\n",
       "      <td>00:45</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>1.999398</td>\n",
       "      <td>3.816537</td>\n",
       "      <td>0.479980</td>\n",
       "      <td>0.395980</td>\n",
       "      <td>00:46</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>2.051997</td>\n",
       "      <td>4.174997</td>\n",
       "      <td>0.430543</td>\n",
       "      <td>0.373515</td>\n",
       "      <td>00:44</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>1.926257</td>\n",
       "      <td>4.096586</td>\n",
       "      <td>0.433852</td>\n",
       "      <td>0.376887</td>\n",
       "      <td>00:44</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>1.931791</td>\n",
       "      <td>4.038434</td>\n",
       "      <td>0.435708</td>\n",
       "      <td>0.376441</td>\n",
       "      <td>00:44</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit_one_cycle(8, 1e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "                background: #F44336;\n",
       "            }\n",
       "        </style>\n",
       "      <progress value='152' class='' max='152', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      100.00% [152/152 00:16<00:00]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "inputs, targets, outputs = get_predictions(learn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Text xxbos pour quelle raison demandez - vous aux émetteurs des renseignements qui n'ont pas à être fournis sur les reçus papier remis aux contribuables ?,\n",
       " Text xxbos why are your requiring xxunk to provide information that is not required to be on the paper receipts given to clients ?,\n",
       " Text xxbos why should you not use the cra to the cra ?)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs[700],targets[700],outputs[700]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Text xxbos quel est l'impact sur la recherche en amont du brevetage accru dans les sciences du vivant ?,\n",
       " Text xxbos what is the impact on upstream research of increased patenting in the life sciences ?,\n",
       " Text xxbos what is the impact of the on the xxunk of the xxunk of the xxunk ?)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs[2513], targets[2513], outputs[2513]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Text xxbos quels changements devrait - on apporter aux processus de réglementation fédéraux et provinciaux ?,\n",
       " Text xxbos what changes to federal and provincial regulatory processes would be required ?,\n",
       " Text xxbos what changes should be made to the regulatory process and the regulatory framework ?)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs[4000], targets[4000], outputs[4000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#get_bleu(learn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bidir"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A second things that might help is to use a bidirectional model for the encoder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Seq2SeqQRNN(nn.Module):\n",
    "    def __init__(self, emb_enc, emb_dec, n_hid, max_len, n_layers=2, p_inp:float=0.15, p_enc:float=0.25, \n",
    "                 p_dec:float=0.1, p_out:float=0.35, p_hid:float=0.05, bos_idx:int=0, pad_idx:int=1):\n",
    "        super().__init__()\n",
    "        self.n_layers,self.n_hid,self.max_len,self.bos_idx,self.pad_idx = n_layers,n_hid,max_len,bos_idx,pad_idx\n",
    "        self.emb_enc = emb_enc\n",
    "        self.emb_enc_drop = nn.Dropout(p_inp)\n",
    "        self.encoder = QRNN(emb_enc.weight.size(1), n_hid, n_layers=n_layers, dropout=p_enc, bidirectional=True)\n",
    "        self.out_enc = nn.Linear(2*n_hid, emb_enc.weight.size(1), bias=False)\n",
    "        self.hid_dp  = nn.Dropout(p_hid)\n",
    "        self.emb_dec = emb_dec\n",
    "        self.decoder = QRNN(emb_dec.weight.size(1), emb_dec.weight.size(1), n_layers=n_layers, dropout=p_dec)\n",
    "        self.out_drop = nn.Dropout(p_out)\n",
    "        self.out = nn.Linear(emb_dec.weight.size(1), emb_dec.weight.size(0))\n",
    "        self.out.weight.data = self.emb_dec.weight.data\n",
    "        self.pr_force = 0.\n",
    "        \n",
    "    def forward(self, inp, targ=None):\n",
    "        bs,sl = inp.size()\n",
    "        hid = self.initHidden(bs)\n",
    "        emb = self.emb_enc_drop(self.emb_enc(inp))\n",
    "        enc_out, hid = self.encoder(emb, hid)\n",
    "        \n",
    "        hid = hid.view(2,self.n_layers, bs, self.n_hid).permute(1,2,0,3).contiguous()\n",
    "        hid = self.out_enc(self.hid_dp(hid).view(self.n_layers, bs, 2*self.n_hid))\n",
    "\n",
    "        dec_inp = inp.new_zeros(bs).long() + self.bos_idx\n",
    "        res = []\n",
    "        for i in range(self.max_len):\n",
    "            emb = self.emb_dec(dec_inp).unsqueeze(1)\n",
    "            outp, hid = self.decoder(emb, hid)\n",
    "            outp = self.out(self.out_drop(outp[:,0]))\n",
    "            res.append(outp)\n",
    "            dec_inp = outp.data.max(1)[1]\n",
    "            if (dec_inp==self.pad_idx).all(): break\n",
    "            if (targ is not None) and (random.random()<self.pr_force):\n",
    "                if i>=targ.shape[1]: break\n",
    "                dec_inp = targ[:,i]\n",
    "        return torch.stack(res, dim=1)\n",
    "    \n",
    "    def initHidden(self, bs): return one_param(self).new_zeros(2*self.n_layers, bs, self.n_hid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "emb_enc = torch.load(path/'models'/'fr_emb.pth')\n",
    "emb_dec = torch.load(path/'models'/'en_emb.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Seq2SeqQRNN(emb_enc, emb_dec, 256, 30, n_layers=2)\n",
    "learn = Learner(data, model, loss_func=seq2seq_loss, metrics=[seq2seq_acc, CorpusBLEU(len(data.y.vocab.itos))],\n",
    "                callback_fns=partial(TeacherForcing, end_epoch=8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
     ]
    }
   ],
   "source": [
    "learn.lr_find()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3XuUXFWZ9/HvU9X3S9JJunO/cQ0wjInQMEQERQmLcRwYVFjwjooKMs6Ivuo48zqLdzmOLhRFx8vLjGMExAvgiMgoOgiB4aJyTZBAyEUMkKRD0t1J0+l7dVfX8/5Rp5Oi6e50kjp16nT9PmvVqqpdp2o/mwr91N77nL3N3RERkdKViDoAERGJlhKBiEiJUyIQESlxSgQiIiVOiUBEpMQpEYiIlDglAhGREhdaIjCzm82szcw2jPHap83MzawxrPpFRGRywuwR3AKcP7rQzBYBq4DtIdYtIiKTVBbWB7v7I2a2dIyXvg78I/DzyX5WY2OjL1061keJiMh41q1bt8fdmw52XGiJYCxmdgGw093Xm9nBjr0KuApg8eLFrF27tgARiohMHWa2bTLHFWyy2MxqgGuAz07meHdf7e7N7t7c1HTQhCYiIoepkGcNHQMcBaw3s5eBhcDTZja3gDGIiMgoBRsacvfngNkjz4Nk0OzuewoVg4iIvF6Yp4/eDjwGLDOzFjO7Iqy6RETk8IV51tBlB3l9aVh1i4jI5OnKYhGREqdEICJS4pQIRESKUFv3ANffu5kX23tCr0uJQESkCL3Q2sO/PbiV3V0DodelRCAiUoS27e0DYMms2tDrUiIQESlC2zv6KE8ac6dVhV6XEoGISBHa3tHLwhk1JBMTr8uWD0oEIiJFaHtHH4tn1hSkLiUCEZEitH2vEoGISMnq7BukayDNkllKBCIiJWl7R/aMoUXqEYiIlKaRRKChIRGREjVyDYESgYhIidrR0UdjXQW1lYXZMkaJQESkyGwr4BlDoEQgIlJ0CnkNASgRiIgUlcF0hl37+pUIRERK1c7OfjIOiwuw2NwIJQIRkSJS6FNHQYlARKSojCSCQl1VDEoEIiJFZfveXirLEjTVVRasztASgZndbGZtZrYhp+wLZvasmT1jZveZ2fyw6hcRiaPtHX0smllDogDLT48Is0dwC3D+qLLr3f0N7r4C+CXw2RDrFxGJnW17+1hSwPkBCDERuPsjQMeosq6cp7WAh1W/iEjcuDs7gh5BIRXm+uUcZnYt8H5gH3BOoesXESlWHb2D9A4OF3SiGCKYLHb3a9x9EXArcPV4x5nZVWa21szWtre3Fy5AEZGIbIvg1FGI9qyh24B3j/eiu69292Z3b25qaipgWCIi0dhRConAzI7LeXoBsLmQ9YuIFLPtewu7Ic2I0OYIzOx24K1Ao5m1AP8MvMPMlgEZYBvwkbDqFxGJm20dfcyZVklVebKg9YaWCNz9sjGKbwqrPhGRuNve0ceSmYVbY2iEriwWESkSUZw6CkoEIiJFYWBomN1dAwWfKAYlAhGRotDyaj/uhV1sboQSgYhIEdje0QsU/owhUCIQESkKz7bswwyOadJksYhISVqzsZVTFs+goaai4HUrEYiIRGxnZz/Pv9LFqpPmRFK/EoGISMTu39gKoEQgIlKq1mxs5eimWo5pqoukfiUCEZEI7esf4vEX93LeSXMji0GJQEQkQg9taSOd8ciGhUCJQEQkUms2ttJYV8kbFzVEFoMSgYhIRAbTGR7e0s65J84u6Gb1oykRiIhE5PEX99KdSkc6LARKBCIikVmzsZXq8iRnHtsYaRxKBCIiEXB31mxs5ezjGwu+Ec1oSgQiIhF4buc+dncNsCrC00ZHKBGIiETggU1tJAzedsLsqENRIhARicLT21/lxHnTmFlb+EXmRlMiEBEpsEzGWb+jk+URXjuQS4lARKTAXt7bS9dAmhULlQhERErSsy37AKZ+
j8DMbjazNjPbkFN2vZltNrNnzewuMyuO/woiIgX0zI5OaiqSHDs7mtVGRwuzR3ALcP6osjXAye7+BuAPwD+FWL+ISFFa39LJyQumk4xwWYlcoSUCd38E6BhVdp+7p4OnjwMLw6pfRKQYDaYzPP9KFyuKZFgIop0j+BBwT4T1i4gU3Jbd3QymMywvkoliiCgRmNk1QBq4dYJjrjKztWa2tr29vXDBiYiE6JmWTgCWL5oecSQHFDwRmNnlwDuBv3Z3H+84d1/t7s3u3tzU1FS4AEVEQrR+RyezaitY0FAddSj7lRWyMjM7H/g/wFvcva+QdYuIFINnW7IXkpkVx0QxhHv66O3AY8AyM2sxsyuAG4B6YI2ZPWNm/xFW/SIixaYnleaFtp6imh+AEHsE7n7ZGMU3hVWfiEixe65lH+7FNT8AurJYRKRg1o9MFBdZj0CJQESkQNbv6GTxzBpmFMGKo7mUCERECqSYVhzNpUQgIlIAbd0DvLJvgOULi2t+AJQIREQK4tkd2RVHi2lpiRFKBCIiBbC+pZNkwviT+eoRiIiUpOd27uO42XVUVySjDuV1lAhERApgV+cAi2fWRB3GmJQIREQKoLV7gDnTqqIOY0xKBCIiIRsYGqazb4g50yqjDmVMSgQiIiFr704BMFs9AhGR0rS7awBAQ0MiIqWqNUgEc5UIRERKU2tXdmhIcwQiIiWqrWuAirIE06vLow5lTEoEIiIha+0aYM60yqLalSyXEoGISMhau1LMqS/O+QFQIhARCV22R6BEICJSspQIRERKWE8qTe/gcNGeMQRKBCIioWot8ovJQIlARCRUI4lgdin2CMzsZjNrM7MNOWUXm9nzZpYxs+aw6hYRKRZt+y8mK80ewS3A+aPKNgDvAh4JsV4RkaIRh6GhsrA+2N0fMbOlo8o2AUV7UYWISL7t7hqgtiJJXWVof26PWNHOEZjZVWa21szWtre3Rx2OiMhhaetKMWd68fYGoIgTgbuvdvdmd29uamqKOhwRkcPS2jVQ1FcVQxEnAhGRqSC7RWXxnjEESgQiIqFx9+w6Q0U8UQzhnj56O/AYsMzMWszsCjO7yMxagJXAr8zs3rDqFxGJ2r7+IQbTmaLdonJEmGcNXTbOS3eFVaeISDE5sEWlhoZERErSyM5kxbpF5YhJJQIzO8bMKoPHbzWzj5tZQ7ihiYjEWxwuJoPJ9wjuBIbN7FjgJuAo4LbQohIRmQLagkTQVD81hoYy7p4GLgK+4e6fBOaFF5aISPy1dqVoqCmnqjwZdSgTmmwiGDKzy4DLgV8GZcW5C7OISJGIw8VkMPlE8EGyp3xe6+4vmdlRwI/CC0tEJP5auwaKevnpEZM6fdTdNwIfBzCzGUC9u18XZmAiInHX2pXi+Dn1UYdxUJM9a+ghM5tmZjOB9cD3zOxfww1NRCS+hjNOe0/xX1UMkx8amu7uXWT3Evieu58KnBteWCIi8ba3N8Vwxov+YjKYfCIoM7N5wCUcmCwWEZFxjOxMVuzLS8DkE8HngXuBre7+lJkdDbwQXlgiIvEWl4vJYPKTxXcAd+Q8fxF4d1hBiYjEXVzWGYLJTxYvNLO7gs3oW83sTjNbGHZwIiJx1dqVwgwa66ZIIgC+B/wCmA8sAO4OykREZAxtXQM01lVSniz+tT0nG2GTu3/P3dPB7RZA+0eKiIyjtav4dyYbMdlEsMfM3mtmyeD2XmBvmIGJiMRZa1cqFstLwOQTwYfInjq6G9gFvIfsshMiIjKG9p5U0a86OmJSicDdt7v7Be7e5O6z3f2vyF5cJiIiY+gZSFNfFdomkHl1JLMYn8pbFCIiU8hwxukfGqa2cuonAstbFCIiU0jvYBqAuhJIBJ63KEREppCegWwiiEuPYMIozaybsf/gG1AdSkQiIjHXm4pXIpiwR+Du9e4+bYxbvbsfLIncHFyJvCGnbKaZrTGzF4L7GflqiIhIsegJEkH9VEgER+gW4PxRZZ8BHnD344AHguci
IlNKb2oYmCI9giPh7o8AHaOKLwS+Hzz+PvBXYdUvIhKVnv1DQ8W9af2IQi+CMcfddwEE97PHO9DMrjKztWa2tr29vWABiogcqZFEUApnDYXK3Ve7e7O7Nzc1aVkjEYmPKTVZHILWYKczgvu2AtcvIhI69Qgm9gvg8uDx5cDPC1y/iEjoelNpyhJGZVnRDrq8RmhRmtntwGPAMjNrMbMrgOuAVWb2ArAqeC4iMqX0ptLUVpZhFo8FGELrt7j7ZeO89Paw6hQRKQY9qeHYDAtBEU8Wi4jEVU9qKDanjoISgYhI3vWqRyAiUtp6gjmCuFAiEBHJs95UWj0CEZFS1qsegYhIaetWj0BEpHS5e9Aj0FlDIiIlaWAoQ8ahrrI86lAmTYlARCSPDqwzpB6BiEhJitvKo6BEICKSVz1KBCIipS1uS1CDEoGISF71KhGIiJQ2DQ2JiJS43tQwoB6BiEjJ6kkNAeiCMhGRUtUT9AhqK9QjEBEpSb2pNLUVSRKJeGxTCUoEIiJ5FbeVR0GJQEQkr3pitvIoKBGIiORV3HYng4gSgZn9bzPbYGbPm9knoohBRCQMcVuCGiJIBGZ2MvBh4HRgOfBOMzuu0HGIiIShJzUcqyWoAaLov5wIPO7ufQBm9jBwEfCVfFc0MDTM0HAGH+O1hBkGmEH20QE2arLfLPf4A+8bzcYqFJGSkt2vOF49gigSwQbgWjObBfQD7wDWhlHRtb/axA8f3xbGR48rm1hemzBGJ5rXvoHXJBbDch4f+BxGktHopGSQsAPve225kRi5T2QfJxMJkglIBmW598lEznFBeVniwH0yYVQkE1SUJSgfuU8YZcns8/KkUZ5MUBbclyeNyrIk1eVJqiuS1FRkH1eWJaksT1CRTFBZnqCqLF6n2olMJI5nDRU8WnffZGZfBtYAPcB6ID36ODO7CrgKYPHixYdV1/knz2XJrJpx4oCM++t6Cx4UjLyy/7k7Gc8+f/27gvLg80aOyd5nn4/VWdj/WfvrzH7GgdcO1J8JHmTcXxO75z7Pec9ILBl3hjNOxp1MBoaD5yNlI4/TmQypdLaNmVHHpEeOGXaGhjMMDWcYTGcYHM4wNDxWf+vQVZQlqCpLUF2RpLaijJrK7H1dZRnTqstpqCmnobqCGbXlTK8up6GmgoagfHp1OfVV5SSVTKQIxPGsoUiidfebgJsAzOyLQMsYx6wGVgM0Nzcf1l+bM49t5MxjG48gUjkYD5LG0LAzlMm8JlkMDTup9DB9g8MMDGbv+4eGGUxnSKUzDKaHSaUzDAxl6B8aZiC49Q4O05tK05NKs7trgC2t3ezrG6I79brfC/uZZdd2mV5dzoyaCmbVVdBYVxncKphZW8GM2gpm1mQfN9SUU1dZpuE8yauh4ey/bfUIJsHMZrt7m5ktBt4FrIwiDjlyZkZZ0ihLQjXhjosODWfo7BtiX/8Q+/oH6ewbyt76s2Vdwf2rfYPs6UmxeVc3e3tT4/ZakgnL9i6qy6mvKqOuqoz6ynLqqsqorUhSXVFGTTCkVVtZRn1VGfVV5UyrCnop1dneSFlSZ2FLVhx3J4OIEgFwZzBHMAR81N1fjSgOiZHyZIKm+kqa6isn/R53p6s/TUffIB29g7zaO0hH3yD7+obo7B9kX382mXQPZHsg7d09dA+ks72XwWEGhzMHraO+qoyGmnJm1lTQUFPBjJpg6CoYtpoeDGHNqMn2UmbVVVATo3VoZPJGlqCuVyI4OHc/K4p6pfSYGdNrypleU85RjbWH/P70cIa+oexQVfdAmu6BIbr600ECGaQzSCSdfYO82pftjby4p4fO3omHsqrLk8yeVsm86VXMn17NvIYq5jdUs3BGDQtnVLOgoZqq8nideSIHlqBWj0BkCilLJpiWTDCtqpx50w/tvenhDF0DB5JGdshqkL09g+ztSdHanWJXZz9PvNTB7q4BhjOvHcKaWZvtXcyoOTC/0Vif7VU01VfSVFfJopk1zJ1WpbOuisSB
TWnilcSVCERCUpZMMLM2OzkNE/dGhjNOa9cAOzv72flqPy2v9vHKvoFsAukdYkdHH8+2dLK3Z5D0qIRRUZZgycwalsyqYfHMWhbNrGbxzBoWz6xh0cwa9SwKKI77FYMSgUhRSCaM+Q3VzG+o5rSl4x+XyTid/UO0d6do7Rpgx6t9bNvbx8t7etm2t4/f/XEv/UPDr3nPvOlVLJlVw1GNtRw3u57Tj5rJifOm6XTbEGiyWERCl0jY/l7Gsrn1r3vd3dnTM8iOV/vY0REkib29vLynl3ufb+X2J3cA2cnM5qUzWHnMLN52whyOnV1X6KZMSeoRiEjkzGz/mVWnLJ7xutdf6eznqZc7eOKlDp58qYMH/3szX/zvzRzdWMuqk+aw6qQ5vHHxDPUWDlOvEoGIFLv5DdVcuGIBF65YAGQTw/2bWlmzsZWbfvsS33nkRRpqynnL8U2cs2w2Zx/fFMxxyGRoaEhEYmd+QzXvX7mU969cStfAEA9vaefBLW08vKWdnz/zCmZwyuIZrDppDueeqCGkg+lOpfevxxUnSgQiAsC0qnL+cvl8/nL5fDIZ57md+/ifzW3cv6mV6+7ZzHX3ZIeQLjltEe87Y0nsfvUWQhz3IgAlAhEZQyJhLF/UwPJFDXxy1fH7h5B+9ewurrtnM995eCtXnnU071+5hPqqeK29H6be1DB1VfH7sxqv/ouIRGJkCOk//2YlP/u7N7FiUQPX37uFN3/5QW74nxf2ny1T6npSaWpjuHyIEoGIHJJTFs/gex88nV9cfSanLZ3BV+/7A2d/5UG++8iLDIy6hqHU9MZwCWpQIhCRw/SGhQ3cePlp/NdHz+RP5k/j2v/exNlfeZBbn9hGJpOffSriJo4b14MSgYgcoRWLGvjhFX/Gj686g8Uza7jmrg285z8eZcvu7qhDK7g4bkoDSgQikidnHD2LOz6ykn+9ZDkv7enlL771G75675aSGi7S0JCIlDwz412nLOSBv38rF6yYzw0P/pF3fOs3vNBaGr2D3tSwhoZERCC7hPa/XrKCH15xOl39aS7690d5YFNr1GGFKpNxegfT1MXwOgIlAhEJzVnHNXH3x87kqMZarvzBWr790Fbcp+ZEct/QMO7xW14ClAhEJGTzplfzk79ZyTvfMJ8v/3ozn/jPZ6bkvEFc1xkCXVksIgVQXZHkW5eu4IS59Xz1vi28tKeX776/mTnTqqIOLW/271esK4tFRMZmZnz0nGNZ/b5m/tjWwwU3/Jb1OzqjDitv9vcIdGWxiMjEVp00hzv/9k2UJRJc8p3HuHv9K1GHlBc9GhoSEZm8E+dN4xdXn8lHfrSOj93+e/7l7udpqq+iqb6SOfWVvPm4Rv7iT+dRlozPb9XeVHbeI47XEUQSsZl9ErgScOA54IPuPhBFLCISjVl1ldx65Rn84LGX2dreS3v3AG3dKTa+so871rXwlV9v4cNnHcUlpy2iJgbDLT2pIQAtQz0ZZrYA+Dhwkrv3m9lPgEuBWwodi4hEq6IswZVnHf2askzGeWBzG995eCufu3sj33jgBT7x9uP4wJlHRRTl5PSoR3BY9Vab2RBQA0yNQUIROWKJhO3fP3ndtg6+cf8LfO7ujezaN8Bn/vwEzIpzP+X9+xXrrKGDc/edwFeB7cAuYJ+731foOESk+J26ZCbf/+DpvO+MJXznkRf5zJ3PMVykK5v2ptIkDKrL4zc0VPBEYGYzgAuBo4D5QK2ZvXeM464ys7Vmtra9vb3QYYpIkUgkjM9f+Cd87G3H8p9rd3D1bU+TShffBWkjm9IUa49lIlFMyZ8LvOTu7e4+BPwMeNPog9x9tbs3u3tzU1NTwYMUkeJhZvz9ecv4v39xIvds2M0Vt6wtul3RegbiuRcBRJMItgNnmFmNZVPn24FNEcQhIjFz5VlH89WLl/PYi3u5dPVjtHenog5pv97BeG5cD9HMETwB/BR4muypowlgdaHjEJF4es+pC7nx/c1sbevl3d9+lJf2
9EYdEpA9a6iuqjzqMA5LJFdruPs/u/sJ7n6yu7/P3YsnrYtI0TvnhNnc9uE/o3tgiPd8+9GiWKoiuymNegQiIgXzxsUzuPNv30R1RZJLVz/OrzfsijSe3mCyOI6UCEQkto5uquNnf/cmls2t5yM/eppv3v8CmYhOL+0eiOc2laBEICIxN7u+ih9fdQbvOmUBX7//D1x9+9P0DRb+jKLsZHE8E0E8oxYRyVFVnuRrFy/nxLnT+NI9m3h5Tx83/K83cnRTXcFi6E2lY3lVMahHICJThJnx4bOP5qYPnMbOzn7+/Ju/4abfvlSQoaJUepihYdfQkIhIMThn2WzWfPJs3nxsI1/45UYuXf042/aGe4rpC609ADTVV4ZaT1iUCERkypk9rYobL2/maxcvZ9PuLs7/xm+4c11LaPXd+XQLFckE5500J7Q6wqREICJTkpnx7lMXsuaTb2HFogb+/o71XPurjXlftG4wneHnz7zCuSfNpqGmIq+fXShKBCIypc2dXsUPrjidy1cu4bu/eYkP3fIU+/qH8vb5D21po6N3kPecujBvn1loSgQiMuWVJxP8y4Un86V3/SmPbt3DRf/2O15s78nLZ/90XQuNdZWcfVx8F8dUIhCRknHZ6Yu59coz6Owf4rLvPs6Ojr4j+ry9PSn+Z3Mb7zplQaz2Vx4tvpGLiByG04+ayW0f/jMGhjL89Y1P0Np1+Nul//yZV0hnnHefEt9hIVAiEJESdMLcaXz/Q6eztyfFe298go7ewcP6nJ+ua+FPF0xn2dz6PEdYWEoEIlKSVixq4MbLT2N7Rx+X3/wk3QOHNoG88ZUuNu7qivUk8QglAhEpWSuPmcW333sKm3Z1cdG/P8qvN+zGfXKnl975dAvlSeOC5fNDjjJ88bweWkQkT952whxuvLyZz9+9kY/8aB1vWDidT606nrcc30R3Ks3Wth62tveye18/VeVJaivLqKlI8l+/38m5J85hRm08rx3IpUQgIiXvrctm8+ZjG7nr9zv55gMv8IHvPcX06vKDXm9wyWmLChRhuJQIRESAsmSCi5sXceGKBfxk7Q6ea9nH0sZajp1dxzFNtcxvqGZwOENfapieVBpwjp0d70niEUoEIiI5KsoSvPeMJWO+VlWeZFpM9yWeiCaLRURKnBKBiEiJUyIQESlxBU8EZrbMzJ7JuXWZ2ScKHYeIiGQVfLLY3bcAKwDMLAnsBO4qdBwiIpIV9dDQ24Gt7r4t4jhEREpW1IngUuD2sV4ws6vMbK2ZrW1vby9wWCIipSOyRGBmFcAFwB1jve7uq9292d2bm5riu+GDiEixs8kusJT3is0uBD7q7udN4th2YPTw0XRg30HKJno+8ji3rBHYM5n4xzBWPIdyTBzbM9FxxdaeiWKdzDH5ak/uY7VncrFO5pjJtGd0WTG3Z7zXDvU7WeLuB/8l7e6R3IAfAx88gvevPljZRM9HHo8qW5vPeA7lmDi2Z6Ljiq09R/od5as9o9qm9hSwPZNpQ7G053C+o/EeT+YWydCQmdUAq4CfHcHH3D2Jsome3z3OMfmM51COiWN7Jjqu2Noz2c8Kuz2TjWMy1J6Jyw9WVsztGe+1UL6TyIaGipGZrXX35qjjyBe1p7ipPcVtqrVnIlGfNVRsVkcdQJ6pPcVN7SluU60941KPQESkxKlHICJS4qZkIjCzm82szcw2HMZ7TzWz58zsj2b2LTOznNc+ZmZbzOx5M/tKfqOeMKa8t8fMPmdmO3PWfHpH/iOfMK5QvqPg9U+bmZtZY/4iPmhMYXxHXzCzZ4Pv5z4zK9jmuCG153oz2xy06S4za8h/5OPGFEZ7Lg7+FmTMLN5zCUdyelSx3oCzgVOADYfx3ieBlYAB9wB/HpSfA9wPVAbPZ8e8PZ8DPj2VvqPgtUXAvWSvO2mMc3uAaTnHfBz4j5i35zygLHj8ZeDLMW/PicAy4CGguVBtCeM2JXsE7v4I0JFbZmbHmNmvzWydmf3GzE4Y/T4z
m0f2f77HPPtN/wD4q+DlvwWuc/dUUEdbuK04IKT2RCrENn0d+EegoJNfYbTH3btyDq2lgG0KqT33uXs6OPRxYGG4rTggpPZs8uwimrE3JRPBOFYDH3P3U4FPA/8+xjELgJac5y1BGcDxwFlm9oSZPWxmp4Ua7cEdaXsArg666Teb2YzwQp20I2qTmV0A7HT39WEHOklH/B2Z2bVmtgP4a+CzIcY6Gfn4NzfiQ2R/XUcpn+2JtZLYs9jM6oA3AXfkDCdXjnXoGGUjv8LKgBnAGcBpwE/M7OjgV0JB5ak93wa+EDz/AvA1sv9zRuJI22TZixSvITv8ELk8fUe4+zXANWb2T8DVwD/nOdRJyVd7gs+6BkgDt+YzxkORz/ZMBSWRCMj2fDrdfUVuoWX3Q1gXPP0F2T+Oud3VhcArweMW4GfBH/4nzSxDdi2SKJZGPeL2uHtrzvu+C/wyzIAn4UjbdAxwFLA++B97IfC0mZ3u7rtDjn0s+fg3l+s24FdElAjIU3vM7HLgncDbo/gRlSPf30+8RT1JEdYNWErOxBDwKHBx8NiA5eO87ymyv/pHJobeEZR/BPh88Ph4YAfBdRgxbc+8nGM+Cfw47t/RqGNepoCTxSF9R8flHPMx4Kcxb8/5wEagqdD/1sL898YUmCyOPICQvvDbgV3AENlf8leQ/bX4a2B98I/xs+O8txnYAGwFbhj5Yw9UAD8KXnsaeFvM2/ND4DngWbK/fOYVqj1htWnUMQVNBCF9R3cG5c+SXTtmQczb80eyP6CeCW6FPAsqjPZcFHxWCmgF7i1Ue/J905XFIiIlrpTOGhIRkTEoEYiIlDglAhGREqdEICJS4pQIRERKnBKBxJKZ9RS4vhvN7KQ8fdZwsKLoBjO7+2CrcJpZg5n9XT7qFhmLTh+VWDKzHnevy+PnlfmBBdFClRu7mX0f+IO7XzvB8UuBX7r7yYWIT0qPegQyZZhZk5ndaWZPBbczg/LTzexRM/t9cL8sKP+Amd1hZncD95nZW83sITP7abBu/q05a88/NLLmvJn1BIvBrTezx81sTlB+TPD8KTP7/CR7LY9xYNG8OjN7wMyetuz69xcGx1wHHBP0Iq4Pjv2HoJ5nzexf8vifUUqQEoFMJd8Evu6ulAK3AAACLklEQVTupwHvBm4MyjcDZ7v7G8mu4PnFnPesBC5397cFz98IfAI4CTgaOHOMemqBx919OfAI8OGc+r8Z1H/Q9WiCdW3eTvbKboAB4CJ3P4Xs/hdfCxLRZ4Ct7r7C3f/BzM4DjgNOB1YAp5rZ2QerT2Q8pbLonJSGc4GTclaTnGZm9cB04PtmdhzZlSPLc96zxt1z16l/0t1bAMzsGbLr0/x2VD2DHFikbx2wKni8kgN7I9wGfHWcOKtzPnsdsCYoN+CLwR/1DNmewpwx3n9ecPt98LyObGJ4ZJz6RCakRCBTSQJY6e79uYVm9v+AB939omC8/aGcl3tHfUYq5/EwY/8/MuQHJtfGO2Yi/e6+wsymk00oHwW+RXbPgSbgVHcfMrOXgaox3m/Al9z9O4dYr8iYNDQkU8l9ZNfsB8DMRpYYng7sDB5/IMT6Hyc7JAVw6cEOdvd9ZLeg/LSZlZONsy1IAucAS4JDu4H6nLfeC3woWFMfM1tgZrPz1AYpQUoEElc1ZtaSc/sU2T+qzcEE6kayS4cDfAX4kpn9DkiGGNMngE+Z2ZPAPGDfwd7g7r8nu/rlpWQ3amk2s7Vkewebg2P2Ar8LTje93t3vIzv09JiZPQf8lNcmCpFDotNHRfIk2CWt393dzC4FLnP3Cw/2PpGoaY5AJH9OBW4IzvTpJMKtP0UOhXoEIiIlTnMEIiIlTolARKTEKRGIiJQ4JQIRkRKnRCAiUuKUCEREStz/BwgclPNoYCW3AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.recorder.plot();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>seq2seq_acc</th>\n",
       "      <th>bleu</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>2.244290</td>\n",
       "      <td>6.343948</td>\n",
       "      <td>0.388536</td>\n",
       "      <td>0.354548</td>\n",
       "      <td>00:47</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>2.042745</td>\n",
       "      <td>3.911313</td>\n",
       "      <td>0.525344</td>\n",
       "      <td>0.378933</td>\n",
       "      <td>00:50</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>1.876625</td>\n",
       "      <td>5.006873</td>\n",
       "      <td>0.409836</td>\n",
       "      <td>0.372162</td>\n",
       "      <td>00:48</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.989081</td>\n",
       "      <td>3.710540</td>\n",
       "      <td>0.503919</td>\n",
       "      <td>0.409202</td>\n",
       "      <td>00:48</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>1.804112</td>\n",
       "      <td>4.398979</td>\n",
       "      <td>0.427331</td>\n",
       "      <td>0.381098</td>\n",
       "      <td>00:47</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>1.949583</td>\n",
       "      <td>4.069941</td>\n",
       "      <td>0.449399</td>\n",
       "      <td>0.394692</td>\n",
       "      <td>00:46</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>1.774466</td>\n",
       "      <td>3.915257</td>\n",
       "      <td>0.452546</td>\n",
       "      <td>0.394610</td>\n",
       "      <td>00:47</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>1.925855</td>\n",
       "      <td>3.910456</td>\n",
       "      <td>0.449511</td>\n",
       "      <td>0.390513</td>\n",
       "      <td>00:46</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit_one_cycle(cyc_len=8, max_lr=1e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "                background: #F44336;\n",
       "            }\n",
       "        </style>\n",
       "      <progress value='152' class='' max='152', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      100.00% [152/152 00:16<00:00]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "[inputs, targets, outputs] = get_predictions(learn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Text xxbos pour quelle raison demandez - vous aux émetteurs des renseignements qui n'ont pas à être fournis sur les reçus papier remis aux contribuables ?,\n",
       " Text xxbos why are your requiring xxunk to provide information that is not required to be on the paper receipts given to clients ?,\n",
       " Text xxbos why do you need to support the information to the the application of the claim ?)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tuple(seq[700] for seq in (inputs, targets, outputs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Text xxbos quels facteurs sont responsables des différences de concentrations des contaminants présents dans les poissons dans les cours d’eau et les lacs du nord ?,\n",
       " Text xxbos what factors are responsible for the differences in the level of contaminants found fish in northern rivers and lakes ?,\n",
       " Text xxbos what factors are the in the in the north and in the north - based production ?)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tuple(seq[701] for seq in (inputs, targets, outputs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Text xxbos en quoi consiste la politique des retombées industrielles et régionales ( rir ) ?,\n",
       " Text xxbos what is the industrial and regional benefits ( irb ) policy ?,\n",
       " Text xxbos what is the policy policy ( policy ) ?)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tuple(seq[4001] for seq in (inputs, targets, outputs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled on purpose: uncomment to run get_bleu on the current learner (get_bleu is defined outside this section).\n",
    "#get_bleu(learn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attention"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Attention is a technique that makes use of the output of our encoder: instead of discarding it entirely, we combine it with the decoder's hidden state to pay attention to specific words of the input sentence when predicting each word of the output sentence. Specifically, at each decoding step we compute attention weights over the encoder outputs, then concatenate to the decoder's input the linear combination of the encoder outputs weighted by those attention weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_param(*sz):\n",
    "    \"Trainable `nn.Parameter` of shape `sz` with standard-normal entries scaled by 1/sqrt(sz[0]).\"\n",
    "    weights = torch.randn(sz)\n",
    "    return nn.Parameter(weights / math.sqrt(sz[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Seq2SeqQRNN(nn.Module):\n",
    "    \"QRNN seq2seq model: bidirectional QRNN encoder, QRNN decoder with additive attention over the encoder outputs.\"\n",
    "    def __init__(self, emb_enc, emb_dec, n_hid, max_len, n_layers=2, p_inp:float=0.15, p_enc:float=0.25,\n",
    "                 p_dec:float=0.1, p_out:float=0.35, p_hid:float=0.05, bos_idx:int=0, pad_idx:int=1,\n",
    "                 mask_pad:bool=True):\n",
    "        # n_hid: hidden size of each encoder direction; max_len: maximum number of decoding steps.\n",
    "        # pad_idx: decoding stops once every prediction equals it, and (if mask_pad) source positions\n",
    "        # equal to it get no attention weight. mask_pad=False keeps the previous, unmasked attention.\n",
    "        # NOTE(review): bos_idx is the first decoder input; confirm it is the target vocab's BOS index.\n",
    "        super().__init__()\n",
    "        self.n_layers,self.n_hid,self.max_len,self.bos_idx,self.pad_idx = n_layers,n_hid,max_len,bos_idx,pad_idx\n",
    "        self.mask_pad = mask_pad\n",
    "        emb_sz = emb_dec.weight.size(1)  # decoder embedding width = decoder hidden size\n",
    "        self.emb_enc = emb_enc\n",
    "        self.emb_enc_drop = nn.Dropout(p_inp)\n",
    "        self.encoder = QRNN(emb_enc.weight.size(1), n_hid, n_layers=n_layers, dropout=p_enc, bidirectional=True)\n",
    "        # Project the concatenated forward/backward encoder state to the *decoder* hidden size, so the\n",
    "        # source and target embeddings do not need to have the same width.\n",
    "        self.out_enc = nn.Linear(2*n_hid, emb_sz, bias=False)\n",
    "        self.hid_dp  = nn.Dropout(p_hid)\n",
    "        self.emb_dec = emb_dec\n",
    "        self.decoder = QRNN(emb_sz + 2*n_hid, emb_sz, n_layers=n_layers, dropout=p_dec)\n",
    "        self.out_drop = nn.Dropout(p_out)\n",
    "        self.out = nn.Linear(emb_sz, emb_dec.weight.size(0))\n",
    "        # Tie the output projection to the decoder embedding by sharing the same weight tensor.\n",
    "        self.out.weight.data = self.emb_dec.weight.data\n",
    "        # Additive attention: score = V . tanh(enc_att(enc_out) + hid_att(top decoder hidden state)).\n",
    "        self.enc_att = nn.Linear(2*n_hid, emb_sz, bias=False)\n",
    "        self.hid_att = nn.Linear(emb_sz, emb_sz)\n",
    "        self.V =  init_param(emb_sz)\n",
    "        # Teacher-forcing probability (presumably updated by the TeacherForcing callback during training).\n",
    "        self.pr_force = 0.\n",
    "\n",
    "    def forward(self, inp, targ=None):\n",
    "        \"Decode up to `max_len` steps from source ids `inp` (bs, sl); return output activations stacked on dim 1.\"\n",
    "        bs,sl = inp.size()\n",
    "        hid = self.initHidden(bs)\n",
    "        emb = self.emb_enc_drop(self.emb_enc(inp))\n",
    "        enc_out, hid = self.encoder(emb, hid)\n",
    "\n",
    "        # Group the final encoder state by layer, then merge both directions into the decoder's initial state.\n",
    "        hid = hid.view(2,self.n_layers, bs, self.n_hid).permute(1,2,0,3).contiguous()\n",
    "        hid = self.out_enc(self.hid_dp(hid).view(self.n_layers, bs, 2*self.n_hid))\n",
    "\n",
    "        # Keep attention off source padding; otherwise a sentence's context vector depends on how much\n",
    "        # padding its batch contains (assumes the source is padded with pad_idx). A row made only of\n",
    "        # padding is left unmasked so softmax never sees a row that is entirely -inf.\n",
    "        pad_mask = None\n",
    "        if self.mask_pad:\n",
    "            pad_mask = inp == self.pad_idx\n",
    "            pad_mask = pad_mask & ~pad_mask.all(1, keepdim=True)\n",
    "\n",
    "        dec_inp = inp.new_zeros(bs).long() + self.bos_idx\n",
    "        res = []\n",
    "        enc_att = self.enc_att(enc_out)  # encoder term of the attention score, computed once\n",
    "        for i in range(self.max_len):\n",
    "            hid_att = self.hid_att(hid[-1])\n",
    "            u = torch.tanh(enc_att + hid_att[:,None])\n",
    "            scores = u @ self.V\n",
    "            if pad_mask is not None: scores = scores.masked_fill(pad_mask, float('-inf'))\n",
    "            attn_wgts = F.softmax(scores, 1)\n",
    "            ctx = (attn_wgts[...,None] * enc_out).sum(1)\n",
    "            emb = self.emb_dec(dec_inp)\n",
    "            outp, hid = self.decoder(torch.cat([emb, ctx], 1)[:,None], hid)\n",
    "            outp = self.out(self.out_drop(outp[:,0]))\n",
    "            res.append(outp)\n",
    "            dec_inp = outp.data.max(1)[1]\n",
    "            if (dec_inp==self.pad_idx).all(): break\n",
    "            if (targ is not None) and (random.random()<self.pr_force):\n",
    "                if i>=targ.shape[1]: break\n",
    "                dec_inp = targ[:,i]\n",
    "        return torch.stack(res, dim=1)\n",
    "\n",
    "    def initHidden(self, bs):\n",
    "        \"Zero initial state for the bidirectional encoder, shape (2*n_layers, bs, n_hid).\"\n",
    "        return one_param(self).new_zeros(2*self.n_layers, bs, self.n_hid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved embedding modules: French for the encoder, English for the decoder.\n",
    "emb_enc, emb_dec = (torch.load(path/'models'/fname) for fname in ('fr_emb.pth', 'en_emb.pth'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 256 hidden units per encoder direction, at most 30 decoding steps.\n",
    "model = Seq2SeqQRNN(emb_enc, emb_dec, n_hid=256, max_len=30, n_layers=2)\n",
    "learn = Learner(\n",
    "    data, model,\n",
    "    loss_func=seq2seq_loss,\n",
    "    metrics=[seq2seq_acc, CorpusBLEU(len(data.y.vocab.itos))],\n",
    "    callback_fns=partial(TeacherForcing, end_epoch=8),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
     ]
    }
   ],
   "source": [
    "learn.lr_find();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAAEKCAYAAAAB0GKPAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xl4HNWZ7/Hvq32xLFuybEvyivFuvIAw20BYHcIABpIQkwkPWQiTmYQkJLlzb27uQC6ZzCSEzCSZTCZxILlkJpCwx4Q9CxAWA/K+gBds8CIvsmRb1q6W3vtHl0wjJLltqdXV0u/zPP2o+9Spqve4bb0+darOMXdHRETkWNKSHYCIiKQGJQwREYmLEoaIiMRFCUNEROKihCEiInFRwhARkbgoYYiISFyUMEREJC5KGCIiEpeMZAfQn0aNGuWTJk1KdhgiIiljxYoVB9y9JJ66gyphTJo0icrKymSHISKSMszsnXjr6pKUiIjERQlDRETiooQhIiJxUcIQEZG4KGGIiEhclDBERCQuShgiIhIXJQwRkRT27MZ9/PT5twbkXEoYIiIp7OkNe7nn5bcH5FxKGCIiKaymvoXiYVkDci4lDBGRFFbT0EpxfvaAnEsJQ0QkhdXUt6qHISIivXN3ahpaGDVMPQwREelFY2s7zW0dFOWrhyEiIr2oqW8FoFgJQ0REenOgoQVAl6RERKR3R3sYGvQWEZHe1NRHexjF6mGIiEhvaho0hiEiInE4UN/CsOwMcjLTB+R8ShgiIimqtqF1wG6phQQnDDP7kpmtN7MNZvblbrafb2aHzWx18Lo1ZtulZrbJzLaa2f9KZJwiIqloIJ/yBshI1IHNbA7wWWAh0Ao8ZWaPu/uWLlX/4u6Xd9k3HfgP4BJgF/C6mS1z942JildEJNUcqG9h3Mi8ATtfInsYM4Hl7t7o7hHgeeDqOPddCGx1923u3gr8BlicoDhFRFJSTUMrowawh5HIhLEeOM/Mis0sD7gMGN9NvbPMbI2ZPWlms4OycmBnTJ1dQdn7mNlNZlZpZpXV1dX9Gb+ISGh1dDi1DYPkkpS7v2Fm3wWeBeqBNUCkS7WVwER3rzezy4BHgamAdXfIHs6zFFgKUFFR0W0dEZHB5nBTG+0dPmBTm0OCB73d/W53P9XdzwNqgS1dtte5e33w/gkg08xGEe1RxPZGxgFViYxVRCSV1DR0PrQ3OC5JYWajg58TgGuA+7psH2tmFrxfGMRTA7wOTDWzyWaWBSwBliUyVhGRVHIgmBZkoOaRggRekgo8ZGbFQBvweXc/aGafA3D3nwIfAf7OzCJAE7DE3R2ImNkXgKeBdOAX7r4hwbGKiKSM2uAp74F8DiOhCcPdz+2m7Kcx738M/LiHfZ8AnkhcdCIiqevdeaQGySUpERFJjM5LUkV5ShgiItKLmoYWRuZlkpE+cL/GlTBERFJQdFqQgRvwBiUMEZGUVFPfOmDTmndSwhARSUEHGloG9JZaUMIQEUlJAz1TLShhiIiknLb2Dg43tQ3oMxighCEiknIOdi7NqktSIiLSm6PTgqiHISIivXl34kH1MEREpBc19Z2XpNTDEBGRXhwI5pEaNYBrYYAShohIyqlpaCUjzRiem+gJx99LCUNEJMXU1LdQlJ9FsJzQgFHCEBFJMdG1vAf2chQoYYiIpJwD9a2MGuABb1DCEBFJOTUNLQM+8SAoYYiIpJxkTG0OShgiIimlsTVCY2v7gD+DAUoYIiIppebotCCDrIdhZl8ys/VmtsHMvtzN9r8xs7XB62Uzmxez7W0zW2dmq82sMpFxioikipqG5DzlDZCwpz7MbA7wWWAh0Ao8ZWaPu/uWmGrbgQ+4+0Ez+xCwFDgjZvsF7n4gUTGKiKSamuAp74Ge2hwS28OYCSx390Z3jwDPA1fHVnD3l939YPBxOTAugfGIiKS8zh7GQK+2B4lNGOuB88ys2Mzy
gMuA8b3U/wzwZMxnB54xsxVmdlMC4xQRSRnJmngQEnhJyt3fMLPvAs8C9cAaINJdXTO7gGjC+KuY4nPcvcrMRgPPmtmb7v5CN/veBNwEMGHChH5uhYhIuNTUt5CbmU5e1sDOIwUJHvR297vd/VR3Pw+oBbZ0rWNmc4G7gMXuXhOzb1Xwcz/wCNGxkO7OsdTdK9y9oqSkJBHNEBEJjZqGgV/Lu1Oi75IaHfycAFwD3Ndl+wTgYeB6d98cU55vZgWd74FFRC9xiYgMaQfqW5Ly0B4k8JJU4CEzKwbagM8Hd0N9DsDdfwrcChQDPwlmXYy4ewUwBngkKMsA7nX3pxIcq4hI6NXUt1JamJOUcyc0Ybj7ud2U/TTm/Y3Ajd3U2QbM61ouIjLU1TS0MKd8eFLOrSe9RURShLtT29BKURKe8gYlDBGRlFHXHKGt3ZMytTkoYYiIpIzOp7wH5V1SIiLSf47OI6VLUiIi0ps399QBMKk4PynnV8IQEUkRy7fVUlaYw/ii3KScXwlDRCQFuDvLt9Vw5knFBM+oDTglDBGRFLB1fz01Da2ccVJR0mJQwhARSQHLt0Wn2jvzpOKkxaCEISKSApZvq6W0MIcJRXlJi0EJQ0Qk5MIwfgFKGCIiodc5fnFmEscvQAlDRCT0lm+vBZI7fgFKGCIiobd8W03Sxy9ACUNEJNTcnVdDMH4BShgiIqH2VnU9B+qTP34BShgiIqH2yrZwjF+AEoaISKiFZfwClDBEREIrTOMXoIQhIhJaYRq/ACUMEZHQCtP4BUBGIg9uZl8CPgsY8HN3/0GX7Qb8ELgMaAQ+6e4rg203AP8nqPpP7n5PouL83tNvEulw0sxIM0gz67X717nFDNzBAdyjP4nun54WfWWkWXC8d7dZcI40ix4kLeazmXWJI1pm79n3vWXpaUZampEee9706Lkz0tLITH+3vLN+Rsw+aTGxZgTvw9D9FRnqwjR+AQlMGGY2h2iyWAi0Ak+Z2ePuviWm2oeAqcHrDOA/gTPMrAi4Dagg+vt4hZktc/eDiYj13ld30Njajjt0uAev4z9O5+9YP4F9wyYjSDpZ6WnkZKaTm5VOTkY6OVnpZKenkZURvIL32RnRejmZaWRnpJOdEVMnqJeXlUFedjr5WRnkBz+H5WQwLDuD7Iw0JSmRLl7fXsvZU8IxfgGJ7WHMBJa7eyOAmT0PXA3cEVNnMfArd3dguZmNMLNS4HzgWXevDfZ9FrgUuC8Rga66ddFx7+PuuHO0B9B1W3uHE+mIJp/2jmjvwx0IkpLzbnLCoT04XkfMz9j9PEhiTvf1oj+hvSP6ua2jg/Z2J9LRQVu7Hy1vd6ejS2yd+0c6nPZ2p63DibR30N7htEQ6aG5rp6mtPfjZQWukncbWCIeaOmiNRF+d9Tp/Hm/CzUizo8ljeE4mBTkZDM/NZHhOJkX5mYwalk1JQfbRn2WFuQzPzQjNPySR/tbe4ew/0sKkUclZjrU7iUwY64Fvm1kx0ET0slNllzrlwM6Yz7uCsp7KQ8NiLjN1ty0j3chIH9iYwsKD5NOZTFrbO2hp66CprZ2G1giNLe3Ut0RoaInQ0BqhviVCfXP055HmCEea26hrjrCztpG6pjZqGlppiXS87zx5WemUjciltDCH0sIcRhfkMHp4NqMLsikpiJaNGZ5DepqSiqSe+uYIAMNzMpMcybsSljDc/Q0z+y7wLFAPrAEiXap19y/Zeyl/HzO7CbgJYMKECSccr/QfMyMz3chMTyM/u+/Hc3fqWyIcqG/lQH0L++ta2HO4iapDzVQdamLP4SY27T3CgfqW9/Vs0tOMscNzKB+RS/nIXMYX5TGpOI+JxXlMLM6nOD9LvRQJpbrmNgCG5w6BhAHg7ncDdwOY2T8T7SnE2gWMj/k8DqgKys/vUv5cD+dYCiwFqKioGASjB9KVmVGQk0lBTiaTe+met3c4NQ3RhFJ9pIU9h5vZfaiRqkPN7D7UxGvba3l09e73jDENy85g
YnEek4rzmTQqmkQqJo7kpJJhA9AykZ4dbgoSRk5Cf00fl0TfJTXa3feb2QTgGuCsLlWWAV8ws98QHfQ+7O57zOxp4J/NbGRQbxHw9UTGKqkvPc2il6UKcnqs0xJpZ9fBJnbUNPJ2TQPvBD83VB3mqQ17aQ+6KDNLh3PFvFIuP6WMCcXhuENFhpYh18MAHgrGMNqAz7v7QTP7HIC7/xR4gujYxlait9V+KthWa2bfAl4PjnN75wC4SF9kZ6QzpWQYU7rpQbS1d7CztpHnNlXz2Noq7nhqE3c8tYl54wpZNHssF84YzYyxBbqEJQOiril8Yxjmg+Ee0EBFRYVXVnYdVxc5MbsONvL42j08sW4Pa3YdBqCsMIcLZ47mgumjOeOkYoZlh+dygQwu91fu5B8eXMuL//MCxo1MXC/XzFa4e0U8dfW3XaQH40bm8bcfmMLffmAK++uaeW5TNX98cx8Pr9zNfy/fQXqaMW9cIWdPGcXZU4o5bdJIsofqrXHS7+qaht4lKZFBYfTwHK49fTzXnj6elkg7lW8f5OW3DvDyWzX85/Nv8eM/b6UgO4NFs8dyxbxSzjl5FJnpmnlHTlxdcwQzGJYVnl/T4YlEJEVkZ6RzzsmjOOfkUQAcaW7j1W21PL1hL09t2MtDK3dRlJ/FpXPGcvaUYuaPH0H5iFyNfchxqWtqoyA7g7QQPUekhCHSRwU5mVw8awwXzxrDP109h+c3VfPY2j08snI39766A4BRw7KZP76QBRNGcsH00cws1eC59K6uqS1Ul6NACUOkX2VnpLNo9lgWzR5La6SDN/fWsWbnIVbtPMSanYf4wxv7+d7Tm44Onl80cwwLJxWRr8Fz6aKuuS1Ud0iBEoZIwmRlpDF33AjmjhvB9cETSPuPNPPnN/fzxzf289CK6OA5RHsgE4vzmFiUx4TiPKaUDGPqmGFMHpWvgfQhqq4pwvDccP2KDlc0IoPc6IIcPnb6BD52+gSa29pZvq2GDVV17Khp5J3aBpZvq+GRmKfR09OMiUV5zCgt4Pzpo7loxmiKh/XDfCsSenXNbaGZ1ryTEoZIkuRkpnP+9NGcP330e8qb29rZfqCBLfvr2brvCFv217NqxyGeWLeXNIOKiUVcMmsMH5w9Vk+hD2IawxCRY8rJTGdm6XBmlg4/WububKiq45mN+3hmw16+/cQbfPuJN5g3fgRXzC3l8rlljC3seUoUST11zZHUHMMwsynALndvMbPzgblE17E4lMjgRCTKzJhTXsic8kK+csk0dtY28sS6PTy2top/ejyaPBZOKmLJwvFcPrdMz4CkuEh7B/Ut4RvDiPdv1UNAu5mdTHT22cnAvQmLSkR6Nb4o+hT6728+lz999QN8+aJpVB9p4ZbfruHc7/6Znz7/1tHZTiX11LeEbx4piP+SVIe7R8zsauAH7v7vZrYqkYGJSHxOKhnGly6eys0Xnszzm6v5+V+28Z0n3+RHf9zCR04bx/nTSzhtYhGFIbseLj07OvFgyL6zeBNGm5ldB9wAXBGUhaslIkNcWppxwYzRXDBjNBuqDnP3i9v5zWs7+dUr72AGM8YO54zJRfzVyaM4b1oJWRm6bBVWR6c2D9FaGBB/wvgU8Dng2+6+3cwmA/+duLBEpC9mlxXyr9fO59tXncKqnQd5fftBXnu7ht++vpP/9/LbjMzL5Mp5ZVxz6jjmjivUU+chE8aJByHOhOHuG4EvAgSLGhW4+3cSGZiI9F1uVnowm+4oYCqtkQ5e3FrNQyt3c9/rO7nnlXc4qSSfjy+cwJKFEzRde0i828NIwYRhZs8BVwb1VwPVZva8u38lgbGJSD/LykjjwhljuHDGGA43tfHkuj08sGIX//T4G/zwj1v4+MIJfPKcSZQW5iY71CHt3TGMcCXweKMpdPc6M7sR+KW732ZmaxMZmIgkVmFuJkuCnsXqnYf4+V+28fO/bOPuF7fz13NL+chp4zjrpGIydIvugAvj8qwQf8LIMLNS4Frg
GwmMR0SSYP74EfzHx09lZ20jv3zpbe6v3MnvVldRnJ/FZaeUcsW8MiomjgzVVNuDWV1TW+jWwoD4E8btwNPAS+7+upmdBGxJXFgikgzji/K49YpZ/MOl03lu034eW7OH+yt38l/L36G0MIcr55dx1fzy9zyFLv2vrjkSurUwIP5B7weAB2I+bwM+nKigRCS5cjLTuXROKZfOKaW+JcIfNu5j2Zoq7v7Ldn72/Damjylg8YIyrlkwTlOSJEAY55GC+Ae9xwH/DpwDOPAi8CV335XA2EQkBIZlZ3DVgnKuWlBOTX0LT6zbw6Orq7jjqU18/5nNLJo1huvPnMhZU4p1e24/CeNaGBD/JalfEp0K5KPB508EZZf0tpOZ3QLcSDTJrAM+5e7NMdv/Dbgg+JgHjHb3EcG29mAfgB3ufmWcsYpIghQPy+b6syZx/VmTeKemgXtf28H9r+/kyfV7mVKSz/VnTuRjp08gN0trePRFGNfCgPjnkipx91+6eyR4/T+gpLcdzKyc6LMbFe4+B0gHlsTWcfdb3H2+u88n2oN5OGZzU+c2JQuR8JlYnM/XPzSTV75+Ed//6DwKcjL55mMbOf/OP3PvqzuItHckO8SUFdYeRrwJ44CZfcLM0oPXJ4CaOPbLAHLNLINoD6Kql7rXAffFGY+IhEROZjofPm0cj37+HO7/27MYNzKP//3IOhb94AWeXLcH71wNSuIW1jGMeBPGp4neUrsX2AN8hOh0IT1y993AncCOYJ/D7v5Md3XNbCLRGXD/FFOcY2aVZrbczK7q6TxmdlNQr7K6ujrO5ohIIiycXMSDnzuLpdefRroZf/frlVzzny+zdpdWQjgeYVwLA+JMGO6+w92vdPcSdx/t7lcB1/S2TzCFyGKiiaAMyA96Jt1ZAjzo7u0xZRPcvQL4OPCDYE2O7mJb6u4V7l5RUtLrVTIRGQBmxqLZY3nqy+dxx4fnsutgE4v/4yW+/vBaahtakx1e6IV1LQyIv4fRnWNNC3IxsN3dq929jej4xNk91F1Cl8tR7l4V/NwGPAcs6EOsIjLA0tOMa08fz5+++gE+c85k7q/cxQV3Psd/LX+H9g5dpupJWNfCgL4ljGPdP7cDONPM8ix6r91FwBvvO4jZdGAk8EpM2Ugzyw7ejyJ6O+/GPsQqIklSkJPJ/7l8Fk9+6Vxmlw3nHx9dz8d/vly9jR6EdS0M6FvC6PW/CO7+KvAgsJLo7bFpwFIzu93MYu96ug74jb93ZGwmUGlma4A/A98JZswVkRQ1bUwBv77xDO786DxW7TzE4v94kU17jyQ7rNAJ61oYcIznMMzsCN0nBgOOOZ2lu98G3Nal+NYudb7ZzX4vA6cc6/giklrMjI+cNo6TRw/js7+q5JqfvMSPrlvARTPHJDu00AjrWhhwjB6Guxe4+/BuXgXuHr70JyIpYf74ESz7wjmcVDKMG39Vyc+ef0u33wbCuhYG9O2SlIjICSstzOX+vz2Ly04p5V+efJPblm2gQ4PhoV0LA+KfGkREpN/lZqXz4+sWUD4il6UvbKO+JcIdH547pNfgCOtaGKCEISJJZmZ8/UMzKMjO4PvPbqaxpZ0fXjef7IyhOR9VWNfCAF2SEpEQMDNuvmgqt14+i6c27OXGeyppbI0kO6ykCOtaGKCEISIh8um/mswdH57LS1sPcMMvXhuSSSOs80iBEoaIhMy1p4/nR9ctYMU7B/n7X6+kbYjNehvWmWpBCUNEQujyuWV8++pTeG5TNV97YM2QunsqrGthgAa9RSSkrls4gdqGVr739CZG5mVx2xWzhsSKfnXNbUwoykt2GN1SwhCR0Pr786dQ29DK3S9upzg/i5svmprskBIuzGMYShgiElpmxjcum8nBhla+/+xmSgqyWbJwQrLDSqiwroUBGsMQkZBLSzO++5G5nDethFt/t4ENVYeTHVLChHktDFDCEJEUkJmexg8+Np+R+ZncfO8qGloG5+22YV4LA5QwRCRFFOVn
8cMlC3i7poF//N36ZIeTEGFeCwOUMEQkhZx5UjE3XziVh1fu5qEVu5IdTr8L81oYoIQhIinm5gtPZuHkIv7xd+vZVl2f7HD6VZjXwgAlDBFJMRnpafxwyXyyMtL4wr2raG5rT3ZI/SbMa2GAEoaIpKDSwlzu/Mg8Nu6p4+cvbEt2OP0mzGthgBKGiKSoi2eN4ZJZY/jZC9uobWhNdjj9IsxrYYAShoiksH/44HQaWyP8+5+2JDuUfhHmtTAgwQnDzG4xsw1mtt7M7jOznC7bP2lm1Wa2OnjdGLPtBjPbErxuSGScIpKapo4p4NqK8fz38nfYWduY7HD6LMxrYUACE4aZlQNfBCrcfQ6QDizppupv3X1+8Lor2LcIuA04A1gI3GZmIxMVq4ikri9fPI00M77/zKZkh9JnYZ5HChJ/SSoDyDWzDCAPqIpzvw8Cz7p7rbsfBJ4FLk1QjCKSwsYW5vDpv5rMo6urWL87tacNCfNaGJDAhOHuu4E7gR3AHuCwuz/TTdUPm9laM3vQzMYHZeXAzpg6u4IyEZH3+dwHpjAiL5M7nk7tXkaY18KAxF6SGgksBiYDZUC+mX2iS7XHgEnuPhf4A3BP5+7dHLLbFVTM7CYzqzSzyurq6v4JXkRSSmFuJl+44GRe2FzNS1sPJDucE3a4aYj2MICLge3uXu3ubcDDwNmxFdy9xt1bgo8/B04L3u8CxsdUHUcPl7Pcfam7V7h7RUlJSb82QERSxyfOnEj5iFy+8+SbuKfmCn11zUN3DGMHcKaZ5Vl0mayLgDdiK5hZaczHK2O2Pw0sMrORQU9lUVAmItKtnMx0brlkGut2H+bpDXuTHc4JqRuqPQx3fxV4EFgJrAvOtdTMbjezK4NqXwxuu11D9I6qTwb71gLfAl4PXrcHZSIiPbp6QTknleTzr89upj3F1gGPtHfQ0No+NMcwANz9Nnef4e5z3P16d29x91vdfVmw/evuPtvd57n7Be7+Zsy+v3D3k4PXLxMZp4gMDulpxi0XT2Pzvnp+vzbemzLD4UhzdFqQwiF6SUpEZMD99SmlzBhbwA//sIVIe0eyw4lb2CceBCUMERlk0tKML188jW0HGnhk1e5khxO3sC+eBEoYIjIIfXD2GOaUD+dHf9pCayQ1ehlhXzwJlDBEZBAyM756yXR21jbxwIqdx94hBMK+eBIoYYjIIHX+9BJOnTCCH/9pa0osshT2qc1BCUNEBikz46uLprPncDP3vbYj2eEc09ExDF2SEhEZeGdPKeb0SSP5xUvb6Qj5cxl1zW2kGeSHdC0MUMIQkUHMzPibMyays7aJ5dtrkh1Or+qa2ijIyQztWhighCEig9ylc8ZSkJPB/a+He/D7YGNbqJ/yBiUMERnkcjLTWTy/jCfX7+VwcCdSGG07UM+k4vxkh9ErJQwRGfQ+VjGBlkgHy9aEc7qQjg5n6/56po0pSHYovVLCEJFBb075cGaMLQjtZamdBxtpbutg2phhyQ6lV0oYIjLomRkfO30863YfZmNVXbLDeZ/N++oBmKoehohI8l01v5ys9DTurwxfL2PzviMATB2tHoaISNKNzM/iktljeHT1bloi4Xrye8u+I5QW5lAQ4plqQQlDRIaQj1WM51BjG89u3JfsUN5j87760F+OAiUMERlCzjl5FOUjcvltiAa/2zuct6rrmRbyy1GghCEiQ0h6mvHh08bx4tYD7D7U1Ofj7T7UxLd+v7FPkxvuqG2kJdIR+ltqQQlDRIaYj542Dnd4ZOWuPh/rmQ17ufvF7Tzah4Wajg54h/yWWlDCEJEhZnxRHgsnFfHo6irc+zYhYVXQS7nrxROf3HDL0YShHoaISOgsXlDG1v31bOjjMxlVh5oB2Lq/nue3VJ/QMTbvq6d8RC7DssM9jxQkOGGY2S1mtsHM1pvZfWaW02X7V8xso5mtNbM/mtnEmG3tZrY6eC1LZJwiMrRcNqeUzHTjd6v7tub37kNNnDG5iDHDs7nrL9tO
6Bhb9tenxOUoSGDCMLNy4ItAhbvPAdKBJV2qrQq2zwUeBO6I2dbk7vOD15WJilNEhp6R+Vl8YFoJy9ZU0d6HdTJ2H2piUnE+nzx7Mi9trTnup8iP3iGVApejIPGXpDKAXDPLAPKA98z85e5/dvfG4ONyYFyC4xERAWDx/HL21bXw6gmuk9ESaaf6SAtlI3L5+MIJ5GWlc9eLx9fLeKemgdZIR+if8O6UsITh7ruBO4EdwB7gsLs/08sunwGejPmcY2aVZrbczK5KVJwiMjRdPHMM+Vnp/G7Vic1gu/dwdPyibEQOhXmZXFsxnsfWVLGvrjnuY3TOITXkexhmNhJYDEwGyoB8M/tED3U/AVQA34spnuDuFcDHgR+Y2ZQe9r0pSCyV1dUnNugkIkNPblY6H5w9lifW7zmh5yg6n+MoH5ELwKfOmUSkw7nn5bfjPkbnHVInD/UeBnAxsN3dq929DXgYOLtrJTO7GPgGcKW7t3SWu3tV8HMb8BywoLuTuPtSd69w94qSkpL+b4WIDFqLF5RzpDnCc5v2H/e+nXdIlQUJY2JxPh+cNZZfv7qDxtZIXMfYvD96h1R+CtwhBYlNGDuAM80sz8wMuAh4I7aCmS0AfkY0WeyPKR9pZtnB+1HAOcDGBMYqIkPQOVOKGTUsi0dP4LJU5zMYYwvfvfnzxnMnc7ipjQdXxPdQ4JZ9R0K/BkasRI5hvEr0zqeVwLrgXEvN7HYz67zr6XvAMOCBLrfPzgQqzWwN8GfgO+6uhCEi/SojPY3L55bxp037j3v51qpDTYwalk1OZvrRstMmjmTe+BHc/eL2Y959FWnvYFt1Q8qMX0CC75Jy99vcfYa7z3H36929xd1vdfdlwfaL3X1M19tn3f1ldz/F3ecFP+9OZJwiMnQtnl9Ga6SDp9fvPa79dh9qonzEex4tw8z47LmTeaemkT++0fuMuG/XNNLa3pEST3h30pPeIjKkzR8/gonFeTx6nA/xVR1qOjp+EevS2WMpH5HLXS9u73X/zgFvXZISEUkRZsbi+eW8sq3m6K2yx+LuVB1q7jZhZKSn8alzJvHa9lrW7jrU4zE6b6lNlTukQAlDRISrF5TjTty9jEONbTS1tXebMACuPX08w7IzuLuXXsZDAqFtAAAMb0lEQVTm/UcYX5RLXlZq3CEFShgiIkwelc+CCSN4eOWuuGaw7foMRlfDczL52OnjeXztnqN3U3W1Zd8Rpo1OnfELUMIQEQHgmlPHsXlffDPYHithAHzy7El0uHPPK2+/b1tbewfbDzSk1IA3KGGIiABwxdzoDLaPxLEYUmevoazLXVKxxhfl8aE5pdz76g4aWt77IN/bBxpoa/eUGvAGJQwREQBG5GVx0Ywx/G71biLtHb3WrTrURHZGGkX5Wb3W+8y5kznSHOGByuga4u7O85ur+dJvVgMwu6ywf4IfIEoYIiKBq08t50B9K3/ZcqDXelWHmikfkUt0EouenTphJKdOGMEvXnqbVTsO8jd3vcoNv3iNuuY2fnTdAqaP1SUpEZGUdMH00YzIy+ThY1yW2t3DMxjdufHck9hR28jVP3mZN/ce4bYrZvHHr36AK+eV9UfIAyp17ucSEUmwrIw0rpxXxm9f30ldcxvDczK7rVd1qInzp8c32emiWWO4an4ZE4rz+ey5kyno4ZipQD0MEZEYVy8opyXSwZPr9nS7vSXSzv5g4aR4ZKSn8YMlC/jKJdNSOlmAEoaIyHvMHz+Ck0bl8/DK7i9L7TscXYUh3oQxmChhiIjEMDOuObWcV7fXsrO28X3b43kGY7BSwhAR6WLx/HIAHu1m8PvdZzCUMEREhrzxRXksnFzEI6t3v2+qkM6EUVrY80N7g5UShohINxbPL2NbdQMb97x3qpCqw02MGpb1noWThgolDBGRblw2p5SMNGPZmvcu37rrYPzPYAw2ShgiIt0YmZ/FuVNH8fs1e+iIWW616lDTkBzwBiUMEZEeXTGvjN2Hmli18yDQ
+8JJQ4EShohIDxbNHkt2RhrLVkcvSx1r4aTBTglDRKQHw7IzuGjmaB5ft4dIe0fMMxhD7w4pSHDCMLNbzGyDma03s/vMLKfL9mwz+62ZbTWzV81sUsy2rwflm8zsg4mMU0SkJ1fMLeNAfSvLt9UO6WcwIIEJw8zKgS8CFe4+B0gHlnSp9hngoLufDPwb8N1g31lB3dnApcBPzGzo3cMmIkl3wYzRDMvOYNma3UoYCT5+BpBrZhlAHlDVZfti4J7g/YPARRadYH4x8Bt3b3H37cBWYGGCYxUReZ+czHQWzR7DU+v38nZNI1kZaRQfY+GkwSphCcPddwN3AjuAPcBhd3+mS7VyYGdQPwIcBopjywO7gjIRkQF35bwy6pojPLp6d1wLJw1WibwkNZJoT2EyUAbkm9knulbrZlfvpby789xkZpVmVlldXd2XkEVEunXOyaMYmZfJoca2XtfxHuwSeUnqYmC7u1e7exvwMHB2lzq7gPEAwWWrQqA2tjwwjvdfzgLA3Ze6e4W7V5SUxLegiYjI8chMT+OyU0oBKCscmuMXkNiEsQM408zygnGJi4A3utRZBtwQvP8I8CePzvS1DFgS3EU1GZgKvJbAWEVEetW5pOpQHfCGBC7R6u6vmtmDwEogAqwClprZ7UCluy8D7gb+y8y2Eu1ZLAn23WBm9wMbg30/7+7tiYpVRORYTp9UxBcvPJkr56feWtz9xbpO3ZvKKioqvLKyMtlhiIikDDNb4e4V8dTVk94iIhIXJQwREYmLEoaIiMRFCUNEROKihCEiInFRwhARkbgoYYiISFyUMEREJC6D6sE9M6sG3gk+FhKd/bar7spjy7pu727bKOBAP4TcU4zHW+9YbeqprKd2x74PU1sT+Z1C/7RV3+nx1TvR77Tr5zB8p8eqG29bB7qdE909von43H1QvoCl8ZbHlnXd3t02olObJCzG4613rDYdT9u6eR+atibyO+2vtuo7HZjvdCDbGm87+6utyfxOj/UazJekHjuO8sd62d7btr6K93jHqnesNvVU1lPb+rudx3PM3urpOz122VD5Trt+DsN3eqy68bY1md9prwbVJamBYmaVHufcK6lObR18hko7Yei0daDaOZh7GIm0NNkBDCC1dfAZKu2EodPWAWmnehgiIhIX9TBERCQuQz5hmNkvzGy/ma0/gX1PM7N1ZrbVzH5kMSvDm9nNZrbJzDaY2R39G/WJSURbzeybZrbbzFYHr8v6P/LjjjUh32mw/Wtm5mY2qv8iPnEJ+k6/ZWZrg+/zGTNL+opBCWrn98zszaCtj5jZiP6P/PglqK0fDX4XdZjZiY91DMStWGF+AecBpwLrT2Df14CzAAOeBD4UlF8A/AHIDj6PTnY7E9jWbwJfS3bbEt3OYNt44Gmiz/qMSnY7E/idDo+p80Xgp4O0nYuAjOD9d4HvJrudCWzrTGA68BxQcaKxDfkehru/QHR52KPMbIqZPWVmK8zsL2Y2o+t+ZlZK9B/WKx79Rn4FXBVs/jvgO+7eEpxjf2JbEZ8EtTV0EtjOfwP+AQjNwF8i2urudTFV8wlBexPUzmfcPRJUXQ6MS2wr4pOgtr7h7pv6GtuQTxg9WArc7O6nAV8DftJNnXJgV8znXUEZwDTgXDN71cyeN7PTExpt3/S1rQBfCLr1vzCzkYkLtU/61E4zuxLY7e5rEh1oP+jzd2pm3zazncDfALcmMNa+6I+/u50+TfR/5GHVn209YRn9ebDBwMyGAWcDD8Rcvs7urmo3ZZ3/E8sARgJnAqcD95vZSUHWD41+aut/At8KPn8L+D7Rf3yh0dd2mlke8A2ilzBCrZ++U9z9G8A3zOzrwBeA2/o51D7pr3YGx/oGEAF+3Z8x9pf+bGtfKWG8XxpwyN3nxxaaWTqwIvi4jOgvytgu7DigKni/C3g4SBCvmVkH0bleqhMZ+Anoc1vdfV/Mfj8Hfp/IgE9QX9s5BZgMrAn+wY4DVprZQnffm+DYj1d/
/P2NdS/wOCFLGPRTO83sBuBy4KKw/YcuRn9/pycu2QM8YXgBk4gZYAJeBj4avDdgXg/7vU60F9E5wHRZUP454Pbg/TRgJ8EzL8l+JaCtpTF1bgF+k+w2JqKdXeq8TUgGvRP0nU6NqXMz8GCy25igdl4KbARKkt22RLc1Zvtz9GHQO+l/MMl+AfcBe4A2oj2DzxD93+RTwJrgL9StPexbAawH3gJ+3JkUgCzgv4NtK4ELk93OBLb1v4B1wFqi/8spHaj2DGQ7u9QJTcJI0Hf6UFC+luh8ReWDtJ1bif5nbnXwSvrdYAls69XBsVqAfcDTJxKbnvQWEZG46C4pERGJixKGiIjERQlDRETiooQhIiJxUcIQEZG4KGHIoGZm9QN8vrvMbFY/Has9mDF2vZk9dqzZVM1shJn9fX+cW6Q7uq1WBjUzq3f3Yf14vAx/d8K6hIqN3czuATa7+7d7qT8J+L27zxmI+GToUQ9DhhwzKzGzh8zs9eB1TlC+0MxeNrNVwc/pQfknzewBM3sMeMbMzjez58zswWA9hV/HrDvwXOd6A2ZWH0zit8bMlpvZmKB8SvD5dTO7Pc5e0Cu8OxHiMDP7o5mttOjaB4uDOt8BpgS9ku8Fdf9HcJ61ZvZ/+/GPUYYgJQwZin4I/Ju7nw58GLgrKH8TOM/dFxCdofWfY/Y5C7jB3S8MPi8AvgzMAk4CzunmPPnAcnefB7wAfDbm/D8Mzn/MuX6COYMuIvokPUAzcLW7n0p07ZXvBwnrfwFvuft8d/8fZrYImAosBOYDp5nZecc6n0hPNPmgDEUXA7NiZv4cbmYFQCFwj5lNJTrLZ2bMPs+6e+waBa+5+y4AM1tNdO6fF7ucp5V3J2NcAVwSvD+Ld9fZuBe4s4c4c2OOvQJ4Nig34J+DX/4dRHseY7rZf1HwWhV8HkY0gbzQw/lEeqWEIUNRGnCWuzfFFprZvwN/dverg/GA52I2N3Q5RkvM+3a6/7fU5u8OEvZUpzdN7j7fzAqJJp7PAz8iukZFCXCau7eZ2dtATjf7G/Av7v6z4zyvSLd0SUqGomeIrvEAgJl1ThtdCOwO3n8ygedfTvRSGMCSY1V298NEl0r9mpllEo1zf5AsLgAmBlWPAAUxuz4NfDpYTwEzKzez0f3UBhmClDBksMszs10xr68Q/eVbEQwEbyQ6HT3AHcC/mNlLQHoCY/oy8BUzew0oBQ4fawd3X0V0ptIlRBf6qTCzSqK9jTeDOjXAS8FtuN9z92eIXvJ6xczWAQ/y3oQiclx0W63IAAtW8GtydzezJcB17r74WPuJJJvGMEQG3mnAj4M7mw4RsiVtRXqiHoaIiMRFYxgiIhIXJQwREYmLEoaIiMRFCUNEROKihCEiInFRwhARkbj8f6owc6WGyt9UAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.recorder.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>seq2seq_acc</th>\n",
       "      <th>bleu</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>2.452436</td>\n",
       "      <td>4.709918</td>\n",
       "      <td>0.412980</td>\n",
       "      <td>0.208454</td>\n",
       "      <td>01:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>2.137345</td>\n",
       "      <td>4.476718</td>\n",
       "      <td>0.422126</td>\n",
       "      <td>0.344813</td>\n",
       "      <td>00:57</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>1.974048</td>\n",
       "      <td>3.824592</td>\n",
       "      <td>0.472997</td>\n",
       "      <td>0.377652</td>\n",
       "      <td>00:58</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.813645</td>\n",
       "      <td>3.864258</td>\n",
       "      <td>0.470798</td>\n",
       "      <td>0.389968</td>\n",
       "      <td>00:57</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>1.818273</td>\n",
       "      <td>4.042902</td>\n",
       "      <td>0.456217</td>\n",
       "      <td>0.390355</td>\n",
       "      <td>00:56</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>1.668895</td>\n",
       "      <td>3.635575</td>\n",
       "      <td>0.482699</td>\n",
       "      <td>0.411627</td>\n",
       "      <td>00:56</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>1.620335</td>\n",
       "      <td>3.741779</td>\n",
       "      <td>0.474715</td>\n",
       "      <td>0.410962</td>\n",
       "      <td>00:56</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>1.852314</td>\n",
       "      <td>3.721396</td>\n",
       "      <td>0.471986</td>\n",
       "      <td>0.402945</td>\n",
       "      <td>00:55</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit_one_cycle(8, 3e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "                background: #F44336;\n",
       "            }\n",
       "        </style>\n",
       "      <progress value='152' class='' max='152', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      100.00% [152/152 00:17<00:00]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "inputs, targets, outputs = get_predictions(learn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Text xxbos pour quelle raison demandez - vous aux émetteurs des renseignements qui n'ont pas à être fournis sur les reçus papier remis aux contribuables ?,\n",
       " Text xxbos why are your requiring xxunk to provide information that is not required to be on the paper receipts given to clients ?,\n",
       " Text xxbos why do you think to the information that the information that not be provided on the payment ?)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs[700], targets[700], outputs[700]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Text xxbos quels facteurs sont responsables des différences de concentrations des contaminants présents dans les poissons dans les cours d’eau et les lacs du nord ?,\n",
       " Text xxbos what factors are responsible for the differences in the level of contaminants found fish in northern rivers and lakes ?,\n",
       " Text xxbos what factors are the of the levels of contaminants in in in water in water and water in the north ?)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs[701], targets[701], outputs[701]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Text xxbos quels sont les avantages et les inconvénients à ce jour de cette approche ?,\n",
       " Text xxbos what are the advantages and disadvantages of this approach to date ?,\n",
       " Text xxbos what are the advantages and disadvantages of this approach ?)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs[4002], targets[4002], outputs[4002]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
